Compare commits

..

11 Commits

Author SHA1 Message Date
df8e9f7e83 refactoring: switched to match/case
wow! python actually added switch cases! too bad this is just syntax sugar...
2023-12-09 12:01:04 -06:00
a22aa808e0 lp: added support for OP_TEST 2022-08-26 01:18:24 -05:00
935844f274 more minor refactoring 2022-08-22 00:59:21 -05:00
c37e9a21d8 ld: minor refactoring 2022-08-22 00:54:59 -05:00
34b1ec7285 ld: added LuaDump()
- chunks can now be serialized back into lua bytecode dumps :D
2022-08-22 00:50:08 -05:00
f9f1d4af00 ld: minor refactoring 2022-08-19 15:46:47 -05:00
3be45f156a lp: support OP_CLOSURE, boilerplate function/proto support 2022-08-17 22:14:45 -05:00
b28edcba1d lp: fix isValidLocal() not respecting capitals 2022-08-17 22:14:13 -05:00
bc4e762e26 lp: bug fix (forgot to transfer registers) 2022-08-16 00:26:50 -05:00
19bed999ee lp: added __parseNewTable(), better table pseudo-code 2022-08-16 00:12:26 -05:00
a248cc4807 lp: added NEWTABLE && SETLIST
- tables can now be (mostly) decompiled
- changed 'decompiled source' to 'pseudo-code' since the output doesn't typically match the compiled script source.
- misc. refactoring
2022-08-15 23:30:32 -05:00
5 changed files with 461 additions and 213 deletions

1
.gitignore vendored
View File

@@ -1,2 +1,3 @@
example.* example.*
__pycache__ __pycache__
NOTES.md

102
README.md
View File

@@ -12,51 +12,95 @@ Lua has a relatively small instruction set (only 38 different opcodes!). This ma
```sh ```sh
> cat example.lua && luac5.1 -o example.luac example.lua > cat example.lua && luac5.1 -o example.luac example.lua
local total = 0 local printMsg = function(append)
local tbl = {"He", "llo", " ", "Wo"}
local str = ""
for i = 0, 9, 1 do for i = 1, #tbl do
total = total + i str = str .. tbl[i]
print(total) end
print(str .. append)
end end
printMsg("rld!")
> python main.py example.luac > python main.py example.luac
example.luac example.luac
==== [[example.lua's constants]] ==== ==== [[example.lua's constants]] ====
0: [NUMBER] 0.0 0: [STRING] rld!
1: [NUMBER] 9.0
2: [NUMBER] 1.0
3: [STRING] print
==== [[example.lua's locals]] ==== ==== [[example.lua's locals]] ====
R[0]: total R[0]: printMsg
R[1]: (for index)
R[2]: (for limit)
R[3]: (for step)
R[4]: i
==== [[example.lua's dissassembly]] ==== ==== [[example.lua's dissassembly]] ====
[ 0] LOADK : R[0] K[0] ; load 0.0 into R[0] [ 0] CLOSURE : R[0] 0 ;
[ 1] LOADK : R[1] K[0] ; load 0.0 into R[1] [ 1] MOVE : 1 0 0 ; move R[0] into R[1]
[ 2] LOADK : R[2] K[1] ; load 9.0 into R[2] [ 2] LOADK : R[2] K[0] ; load "rld!" into R[2]
[ 3] LOADK : R[3] K[2] ; load 1.0 into R[3] [ 3] CALL : 1 2 1 ;
[ 4] FORPREP : R[1] 4 ; [ 4] RETURN : 0 1 0 ;
[ 5] ADD : R[0] R[0] R[4] ; add R[4] to R[0], place into R[0]
[ 6] GETGLOBAL : R[5] K[3] ; move _G["print"] into R[5]
[ 7] MOVE : 6 0 0 ; move R[0] into R[6]
[ 8] CALL : 5 2 1 ;
[ 9] FORLOOP : R[1] -5 ;
[ 10] RETURN : 0 1 0 ;
==== [[example.lua's decompiled source]] ==== ==== [[example.lua's protos]] ====
local total = 0.0
for i = 0.0, 9.0, 1.0 do ==== [['s constants]] ====
total = (total + i)
print(total) 0: [STRING] He
1: [STRING] llo
2: [STRING]
3: [STRING] Wo
4: [STRING]
5: [NUMBER] 1.0
6: [STRING] print
==== [['s locals]] ====
R[0]: append
R[1]: tbl
R[2]: str
R[3]: (for index)
R[4]: (for limit)
R[5]: (for step)
R[6]: i
==== [['s dissassembly]] ====
[ 0] NEWTABLE : 1 4 0 ;
[ 1] LOADK : R[2] K[0] ; load "He" into R[2]
[ 2] LOADK : R[3] K[1] ; load "llo" into R[3]
[ 3] LOADK : R[4] K[2] ; load " " into R[4]
[ 4] LOADK : R[5] K[3] ; load "Wo" into R[5]
[ 5] SETLIST : 1 4 1 ;
[ 6] LOADK : R[2] K[4] ; load "" into R[2]
[ 7] LOADK : R[3] K[5] ; load 1 into R[3]
[ 8] LEN : 4 1 0 ;
[ 9] LOADK : R[5] K[5] ; load 1 into R[5]
[ 10] FORPREP : R[3] 3 ;
[ 11] MOVE : 7 2 0 ; move R[2] into R[7]
[ 12] GETTABLE : R[8] 1 R[6] ;
[ 13] CONCAT : 2 7 8 ; concat 2 values from R[7] to R[8], store into R[2]
[ 14] FORLOOP : R[3] -4 ;
[ 15] GETGLOBAL : R[3] K[6] ; move _G["print"] into R[3]
[ 16] MOVE : 4 2 0 ; move R[2] into R[4]
[ 17] MOVE : 5 0 0 ; move R[0] into R[5]
[ 18] CONCAT : 4 4 5 ; concat 2 values from R[4] to R[5], store into R[4]
[ 19] CALL : 3 2 1 ;
[ 20] RETURN : 0 1 0 ;
==== [[example.lua's pseudo-code]] ====
local printMsg = function(append)
local tbl = {"He", "llo", " ", "Wo", }
local str = ""
for i = 1, #tbl, 1 do
str = str .. tbl[i]
end
print(str .. append)
end end
printMsg("rld!")
``` ```

View File

@@ -6,9 +6,6 @@
An experimental bytecode decompiler. An experimental bytecode decompiler.
''' '''
from operator import concat
from subprocess import call
from xmlrpc.client import Boolean
from lundump import Chunk, Constant, Instruction, Opcodes, whichRK, readRKasK from lundump import Chunk, Constant, Instruction, Opcodes, whichRK, readRKasK
class _Scope: class _Scope:
@@ -30,14 +27,19 @@ class _Line:
self.scope = scope self.scope = scope
def isValidLocal(ident: str) -> bool: def isValidLocal(ident: str) -> bool:
for c in ident: # has to start with an alpha or _
if c not in "abcdefghijklmnopqrstuvwxyz1234567890_": if ident[0] not in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ_":
return False
# then it can be alphanum or _
for c in ident[1:]:
if c not in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_":
return False return False
return True return True
class LuaDecomp: class LuaDecomp:
def __init__(self, chunk: Chunk): def __init__(self, chunk: Chunk, headChunk: bool = True, scopeOffset: int = 0):
self.chunk = chunk self.chunk = chunk
self.pc = 0 self.pc = 0
self.scope: list[_Scope] = [] self.scope: list[_Scope] = []
@@ -46,6 +48,8 @@ class LuaDecomp:
self.locals = {} self.locals = {}
self.traceback = {} self.traceback = {}
self.unknownLocalCount = 0 self.unknownLocalCount = 0
self.headChunk = headChunk
self.scopeOffset = scopeOffset # number of scopes this chunk/proto is in
self.src: str = "" self.src: str = ""
# configurations! # configurations!
@@ -55,6 +59,20 @@ class LuaDecomp:
self.__loadLocals() self.__loadLocals()
if not self.headChunk:
functionProto = "function("
# define params
for i in range(self.chunk.numParams):
# add param to function prototype (also make a local in the register if it doesn't exist)
functionProto += ("%s, " if i+1 < self.chunk.numParams else "%s") % self.__makeLocalIdentifier(i)
# mark local as defined
self.__addSetTraceback(i)
functionProto += ")"
self.__startScope(functionProto, 0, len(self.chunk.instructions))
# parse instructions # parse instructions
while self.pc < len(self.chunk.instructions): while self.pc < len(self.chunk.instructions):
self.parseInstr() self.parseInstr()
@@ -63,12 +81,18 @@ class LuaDecomp:
# end the scope (if we're supposed too) # end the scope (if we're supposed too)
self.__checkScope() self.__checkScope()
print("\n==== [[" + str(self.chunk.name) + "'s decompiled source]] ====\n") if not self.headChunk:
self.__endScope()
def getPseudoCode(self) -> str:
fullSrc = ""
for line in self.lines: for line in self.lines:
if self.annotateLines: if self.annotateLines:
print("-- PC: %d to PC: %d" % (line.startPC, line.endPC)) fullSrc += "-- PC: %d to PC: %d\n" % (line.startPC, line.endPC)
print(((' ' * self.indexWidth) * line.scope) + line.src) fullSrc += ((' ' * self.indexWidth) * (line.scope + self.scopeOffset)) + line.src + "\n"
return fullSrc
# =======================================[[ Helpers ]]========================================= # =======================================[[ Helpers ]]=========================================
@@ -121,7 +145,7 @@ class LuaDecomp:
self.src = "" self.src = ""
# walks traceback, if local wasn't set before, the local needs to be defined # walks traceback, if local wasn't set before, the local needs to be defined
def __needsDefined(self, reg) -> Boolean: def __needsDefined(self, reg) -> bool:
for _, trace in self.traceback.items(): for _, trace in self.traceback.items():
if reg in trace.sets: if reg in trace.sets:
return False return False
@@ -147,7 +171,7 @@ class LuaDecomp:
# if the top indx is a local, get it # if the top indx is a local, get it
return self.locals[indx] if indx in self.locals else self.top[indx] return self.locals[indx] if indx in self.locals else self.top[indx]
def __setReg(self, indx: int, code: str) -> None: def __setReg(self, indx: int, code: str, forceLocal: bool = False) -> None:
# if the top indx is a local, set it # if the top indx is a local, set it
if indx in self.locals: if indx in self.locals:
if self.__needsDefined(indx): if self.__needsDefined(indx):
@@ -155,10 +179,9 @@ class LuaDecomp:
else: else:
self.__addExpr(self.locals[indx] + " = " + code) self.__addExpr(self.locals[indx] + " = " + code)
self.__endStatement() self.__endStatement()
elif self.aggressiveLocals: # 'every register is a local!!' elif self.aggressiveLocals or forceLocal: # 'every register is a local!!'
self.__newLocal(indx, code) self.__newLocal(indx, code)
self.__addSetTraceback(indx) self.__addSetTraceback(indx)
self.top[indx] = code self.top[indx] = code
@@ -176,7 +199,6 @@ class LuaDecomp:
return self.locals[indx] return self.locals[indx]
def __newLocal(self, indx: int, expr: str) -> None: def __newLocal(self, indx: int, expr: str) -> None:
# TODO: grab identifier from chunk(?)
self.__makeLocalIdentifier(indx) self.__makeLocalIdentifier(indx)
self.__addExpr("local " + self.locals[indx] + " = " + expr) self.__addExpr("local " + self.locals[indx] + " = " + expr)
@@ -202,12 +224,15 @@ class LuaDecomp:
self.__addExpr("end") self.__addExpr("end")
self.scope.pop() self.scope.pop()
self.__endStatement()
# =====================================[[ Instructions ]]====================================== # =====================================[[ Instructions ]]======================================
def __emitOperand(self, a: int, b: str, c: str, op: str) -> None: def __emitOperand(self, a: int, b: str, c: str, op: str) -> None:
self.__setReg(a, "(" + b + op + c + ")") self.__setReg(a, "(" + b + op + c + ")")
def __compJmp(self, op: str): # handles conditional jumps
def __condJmp(self, op: str, rkBC: bool = True):
instr = self.__getCurrInstr() instr = self.__getCurrInstr()
jmpType = "if" jmpType = "if"
scopeStart = "then" scopeStart = "then"
@@ -230,7 +255,13 @@ class LuaDecomp:
self.__addExpr("%s not " % jmpType) self.__addExpr("%s not " % jmpType)
else: else:
self.__addExpr("%s " % jmpType) self.__addExpr("%s " % jmpType)
self.__addExpr(self.__readRK(instr.B) + op + self.__readRK(instr.C) + " ")
# write actual comparison
if rkBC:
self.__addExpr(self.__readRK(instr.B) + op + self.__readRK(instr.C) + " ")
else: # just testing rkB
self.__addExpr(op + self.__readRK(instr.B))
self.pc += 1 # skip next instr self.pc += 1 # skip next instr
if scopeStart: if scopeStart:
self.__startScope("%s " % scopeStart, self.pc - 1, jmp) self.__startScope("%s " % scopeStart, self.pc - 1, jmp)
@@ -256,102 +287,160 @@ class LuaDecomp:
else: else:
return self.__getReg(rk) return self.__getReg(rk)
# walk & peak ahead NEWTABLE
def __parseNewTable(self, indx: int):
# TODO: parse SETTABLE too?
tblOps = [Opcodes.LOADK, Opcodes.SETLIST]
instr = self.__getNextInstr()
cachedRegs = {}
tbl = "{"
while instr.opcode in tblOps:
if instr.opcode == Opcodes.LOADK: # operate on registers
cachedRegs[instr.A] = self.chunk.getConstant(instr.B).toCode()
elif instr.opcode == Opcodes.SETLIST:
numElems = instr.B
for i in range(numElems):
tbl += "%s, " % cachedRegs[instr.A + i + 1]
del cachedRegs[instr.A + i + 1]
self.pc += 1
instr = self.__getNextInstr()
tbl += "}"
# i use forceLocal here even though i don't know *for sure* that the register is a local.
# this does help later though if the table is reused (which is 99% of the time). the other 1%
# only affects syntax and may look a little weird but is fine and equivalent non-the-less
self.__setReg(indx, tbl, forceLocal=True)
self.__endStatement()
# if we have leftovers... oops, set those
for i, v in cachedRegs.items():
self.__setReg(i, v)
def parseInstr(self): def parseInstr(self):
instr = self.__getCurrInstr() instr = self.__getCurrInstr()
# python, add switch statements *please* match instr.opcode:
if instr.opcode == Opcodes.MOVE: # move is a fake ABC instr, C is ignored case Opcodes.MOVE: # move is a fake ABC instr, C is ignored
# move registers # move registers
self.__setReg(instr.A, self.__getReg(instr.B)) self.__setReg(instr.A, self.__getReg(instr.B))
elif instr.opcode == Opcodes.LOADK: case Opcodes.LOADK:
self.__setReg(instr.A, self.chunk.getConstant(instr.B).toCode()) self.__setReg(instr.A, self.chunk.getConstant(instr.B).toCode())
elif instr.opcode == Opcodes.LOADBOOL: case Opcodes.LOADBOOL:
if instr.B == 0: if instr.B == 0:
self.__setReg(instr.A, "false") self.__setReg(instr.A, "false")
else: else:
self.__setReg(instr.A, "true") self.__setReg(instr.A, "true")
elif instr.opcode == Opcodes.GETGLOBAL: case Opcodes.GETGLOBAL:
self.__setReg(instr.A, self.chunk.getConstant(instr.B).data) self.__setReg(instr.A, self.chunk.getConstant(instr.B).data)
elif instr.opcode == Opcodes.GETTABLE: case Opcodes.GETTABLE:
self.__setReg(instr.A, self.__getReg(instr.B) + "[" + self.__readRK(instr.C) + "]") self.__setReg(instr.A, self.__getReg(instr.B) + "[" + self.__readRK(instr.C) + "]")
elif instr.opcode == Opcodes.SETGLOBAL: case Opcodes.SETGLOBAL:
self.__addExpr(self.chunk.getConstant(instr.B).data + " = " + self.__getReg(instr.A)) self.__addExpr(self.chunk.getConstant(instr.B).data + " = " + self.__getReg(instr.A))
self.__endStatement() self.__endStatement()
elif instr.opcode == Opcodes.SETTABLE: case Opcodes.SETTABLE:
self.__addExpr(self.__getReg(instr.A) + "[" + self.__readRK(instr.B) + "] = " + self.__readRK(instr.C)) self.__addExpr(self.__getReg(instr.A) + "[" + self.__readRK(instr.B) + "] = " + self.__readRK(instr.C))
self.__endStatement() self.__endStatement()
elif instr.opcode == Opcodes.ADD: case Opcodes.NEWTABLE:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " + ") self.__parseNewTable(instr.A)
elif instr.opcode == Opcodes.SUB: case Opcodes.ADD:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " - ") self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " + ")
elif instr.opcode == Opcodes.MUL: case Opcodes.SUB:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " * ") self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " - ")
elif instr.opcode == Opcodes.DIV: case Opcodes.MUL:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " / ") self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " * ")
elif instr.opcode == Opcodes.MOD: case Opcodes.DIV:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " % ") self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " / ")
elif instr.opcode == Opcodes.POW: case Opcodes.MOD:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " ^ ") self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " % ")
elif instr.opcode == Opcodes.UNM: case Opcodes.POW:
self.__setReg(instr.A, "-" + self.__getReg(instr.B)) self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " ^ ")
elif instr.opcode == Opcodes.NOT: case Opcodes.UNM:
self.__setReg(instr.A, "not " + self.__getReg(instr.B)) self.__setReg(instr.A, "-" + self.__getReg(instr.B))
elif instr.opcode == Opcodes.LEN: case Opcodes.NOT:
self.__setReg(instr.A, "#" + self.__getCurrInstr(instr.B)) self.__setReg(instr.A, "not " + self.__getReg(instr.B))
elif instr.opcode == Opcodes.CONCAT: case Opcodes.LEN:
count = instr.C-instr.B+1 self.__setReg(instr.A, "#" + self.__getReg(instr.B))
concatStr = "" case Opcodes.CONCAT:
count = instr.C-instr.B+1
concatStr = ""
# concat all items on stack from RC to RB # concat all items on stack from RC to RB
for i in range(count): for i in range(count):
concatStr += self.__getReg(instr.B + i) + (" .. " if not i == count - 1 else "") concatStr += self.__getReg(instr.B + i) + (" .. " if not i == count - 1 else "")
self.__setReg(instr.A, concatStr) self.__setReg(instr.A, concatStr)
elif instr.opcode == Opcodes.JMP: case Opcodes.JMP:
pass pass
elif instr.opcode == Opcodes.EQ: case Opcodes.EQ:
self.__compJmp(" == ") self.__condJmp(" == ")
elif instr.opcode == Opcodes.LT: case Opcodes.LT:
self.__compJmp(" < ") self.__condJmp(" < ")
elif instr.opcode == Opcodes.LE: case Opcodes.LE:
self.__compJmp(" <= ") self.__condJmp(" <= ")
elif instr.opcode == Opcodes.CALL: case Opcodes.TEST:
preStr = "" if instr.C == 0:
callStr = "" self.__condJmp("", False)
ident = "" else:
self.__condJmp("not ", False)
case Opcodes.CALL:
preStr = ""
callStr = ""
ident = ""
# parse arguments # parse arguments
callStr += self.__getReg(instr.A) + "(" callStr += self.__getReg(instr.A) + "("
for i in range(instr.A + 1, instr.A + instr.B): for i in range(instr.A + 1, instr.A + instr.B):
callStr += self.__getReg(i) + (", " if not i + 1 == instr.A + instr.B else "") callStr += self.__getReg(i) + (", " if not i + 1 == instr.A + instr.B else "")
callStr += ")" callStr += ")"
# parse return values # parse return values
if instr.C > 1: if instr.C > 1:
preStr = "local " preStr = "local "
for indx in range(instr.A, instr.A + instr.C - 1): for indx in range(instr.A, instr.A + instr.C - 1):
if indx in self.locals: if indx in self.locals:
ident = self.locals[indx] ident = self.locals[indx]
else: else:
ident = self.__makeLocalIdentifier(indx) ident = self.__makeLocalIdentifier(indx)
preStr += ident preStr += ident
# normally setReg() does this # normally setReg() does this
self.top[indx] = ident self.top[indx] = ident
# just so we don't have a trailing ', ' # just so we don't have a trailing ', '
preStr += ", " if not indx == instr.A + instr.C - 2 else "" preStr += ", " if not indx == instr.A + instr.C - 2 else ""
preStr += " = " preStr += " = "
self.__addExpr(preStr + callStr) self.__addExpr(preStr + callStr)
self.__endStatement() self.__endStatement()
elif instr.opcode == Opcodes.RETURN: case Opcodes.RETURN:
self.__endStatement() self.__endStatement()
pass # no-op for now pass # no-op for now
elif instr.opcode == Opcodes.FORLOOP: case Opcodes.FORLOOP:
pass # no-op for now pass # no-op for now
elif instr.opcode == Opcodes.FORPREP: case Opcodes.FORPREP:
self.__addExpr("for %s = %s, %s, %s " % (self.__getLocal(instr.A+3), self.__getReg(instr.A), self.__getReg(instr.A + 1), self.__getReg(instr.A + 2))) self.__addExpr("for %s = %s, %s, %s " % (self.__getLocal(instr.A+3), self.__getReg(instr.A), self.__getReg(instr.A + 1), self.__getReg(instr.A + 2)))
self.__startScope("do", self.pc, instr.B) self.__startScope("do", self.pc, instr.B)
else: case Opcodes.SETLIST:
raise Exception("unsupported instruction: %s" % instr.toString()) # LFIELDS_PER_FLUSH (50) is the number of elements that *should* have been set in the list in the *last* SETLIST
# eg.
# [ 49] LOADK : R[49] K[1] ; load 0.0 into R[49]
# [ 50] LOADK : R[50] K[1] ; load 0.0 into R[50]
# [ 51] SETLIST : 0 50 1 ; sets list[1..50]
# [ 52] LOADK : R[1] K[1] ; load 0.0 into R[1]
# [ 53] SETLIST : 0 1 2 ; sets list[51..51]
numElems = instr.B
startAt = ((instr.C - 1) * 50)
ident = self.__getLocal(instr.A)
# set each index (TODO: make tables less verbose)
for i in range(numElems):
self.__addExpr("%s[%d] = %s" % (ident, (startAt + i + 1), self.__getReg(instr.A + i + 1)))
self.__endStatement()
case Opcodes.CLOSURE:
proto = LuaDecomp(self.chunk.protos[instr.B], headChunk=False, scopeOffset=len(self.scope))
self.__setReg(instr.A, proto.getPseudoCode())
case _:
raise Exception("unsupported instruction: %s" % instr.toString())

View File

@@ -1,7 +1,7 @@
''' '''
l(un)dump.py l(un)dump.py
A Lua5.1 cross-platform bytecode deserializer. This module pulls int and size_t sizes from the A Lua5.1 cross-platform bytecode deserializer && serializer. This module pulls int and size_t sizes from the
chunk header, meaning it should be able to deserialize lua bytecode dumps from most platforms, chunk header, meaning it should be able to deserialize lua bytecode dumps from most platforms,
regardless of the host machine. regardless of the host machine.
@@ -9,11 +9,9 @@
as well as read the lundump.c source file from the Lua5.1 source. as well as read the lundump.c source file from the Lua5.1 source.
''' '''
from multiprocessing.spawn import get_executable
import struct import struct
import array import array
from enum import IntEnum, Enum, auto from enum import IntEnum, Enum, auto
from typing_extensions import Self
class InstructionType(Enum): class InstructionType(Enum):
ABC = auto(), ABC = auto(),
@@ -70,6 +68,8 @@ _RKBCInstr = [Opcodes.SETTABLE, Opcodes.ADD, Opcodes.SUB, Opcodes.MUL, Opcodes.D
_RKCInstr = [Opcodes.GETTABLE, Opcodes.SELF] _RKCInstr = [Opcodes.GETTABLE, Opcodes.SELF]
_KBx = [Opcodes.LOADK, Opcodes.GETGLOBAL, Opcodes.SETGLOBAL] _KBx = [Opcodes.LOADK, Opcodes.GETGLOBAL, Opcodes.SETGLOBAL]
_LUAMAGIC = b'\x1bLua'
# is an 'RK' value a K? (result is true for K, false for R) # is an 'RK' value a K? (result is true for K, false for R)
def whichRK(rk: int): def whichRK(rk: int):
return (rk & (1 << 8)) > 0 return (rk & (1 << 8)) > 0
@@ -152,7 +152,7 @@ class Constant:
self.data = data self.data = data
def toString(self): def toString(self):
return "[" + self.type.name + "] " + str(self.data) return "[%s] %s" % (self.type.name, str(self.data))
# format the constant so that it is parsable by lua # format the constant so that it is parsable by lua
def toCode(self): def toCode(self):
@@ -164,7 +164,7 @@ class Constant:
else: else:
return "false" return "false"
elif self.type == ConstType.NUMBER: elif self.type == ConstType.NUMBER:
return str(self.data) return "%g" % self.data
else: else:
return "nil" return "nil"
@@ -189,6 +189,7 @@ class Chunk:
self.maxStack: int = 0 self.maxStack: int = 0
self.upvalues: list[str] = [] self.upvalues: list[str] = []
self.lineNums: list[int] = []
self.locals: list[Local] = [] self.locals: list[Local] = []
def appendInstruction(self, instr: Instruction): def appendInstruction(self, instr: Instruction):
@@ -200,9 +201,15 @@ class Chunk:
def appendProto(self, proto): def appendProto(self, proto):
self.protos.append(proto) self.protos.append(proto)
def appendLine(self, line: int):
self.lineNums.append(line)
def appendLocal(self, local: Local): def appendLocal(self, local: Local):
self.locals.append(local) self.locals.append(local)
def appendUpval(self, upval: str):
self.upvalues.append(upval)
def findLocal(self, pc: int) -> Local: def findLocal(self, pc: int) -> Local:
for l in self.locals: for l in self.locals:
if l.start <= pc and l.end >= pc: if l.start <= pc and l.end >= pc:
@@ -298,11 +305,7 @@ class LuaUndump:
self.rootChunk: Chunk = None self.rootChunk: Chunk = None
self.index = 0 self.index = 0
@staticmethod def _loadBlock(self, sz) -> bytearray:
def dis_chunk(chunk: Chunk):
chunk.print()
def loadBlock(self, sz) -> bytearray:
if self.index + sz > len(self.bytecode): if self.index + sz > len(self.bytecode):
raise Exception("Malformed bytecode!") raise Exception("Malformed bytecode!")
@@ -310,82 +313,71 @@ class LuaUndump:
self.index = self.index + sz self.index = self.index + sz
return temp return temp
def get_byte(self) -> int: def _get_byte(self) -> int:
return self.loadBlock(1)[0] return self._loadBlock(1)[0]
def get_int32(self) -> int: def _get_uint32(self) -> int:
if (self.big_endian): order = 'big' if self.big_endian else 'little'
return int.from_bytes(self.loadBlock(4), byteorder='big', signed=False) return int.from_bytes(self._loadBlock(4), byteorder=order, signed=False)
else:
return int.from_bytes(self.loadBlock(4), byteorder='little', signed=False)
def get_int(self) -> int: def _get_uint(self) -> int:
if (self.big_endian): order = 'big' if self.big_endian else 'little'
return int.from_bytes(self.loadBlock(self.int_size), byteorder='big', signed=False) return int.from_bytes(self._loadBlock(self.int_size), byteorder=order, signed=False)
else:
return int.from_bytes(self.loadBlock(self.int_size), byteorder='little', signed=False)
def get_size_t(self) -> int: def _get_size_t(self) -> int:
if (self.big_endian): order = 'big' if self.big_endian else 'little'
return int.from_bytes(self.loadBlock(self.size_t), byteorder='big', signed=False) return int.from_bytes(self._loadBlock(self.size_t), byteorder=order, signed=False)
else:
return int.from_bytes(self.loadBlock(self.size_t), byteorder='little', signed=False)
def get_double(self) -> int: def _get_double(self) -> int:
if self.big_endian: order = '>d' if self.big_endian else '<d'
return struct.unpack('>d', self.loadBlock(8))[0] return struct.unpack(order, self._loadBlock(self.l_number_size))[0]
else:
return struct.unpack('<d', self.loadBlock(8))[0]
def get_string(self, size) -> str: def _get_string(self) -> str:
if (size == None): size = self._get_size_t()
size = self.get_size_t() if (size == 0):
if (size == 0): return ""
return ""
return "".join(chr(x) for x in self.loadBlock(size)) # [:-1] to remove the NULL terminator
return ("".join(chr(x) for x in self._loadBlock(size)))[:-1]
def decode_chunk(self) -> Chunk: def decode_chunk(self) -> Chunk:
chunk = Chunk() chunk = Chunk()
chunk.name = self.get_string(None) # chunk meta info
chunk.frst_line = self.get_int() chunk.name = self._get_string()
chunk.last_line = self.get_int() chunk.frst_line = self._get_uint()
chunk.last_line = self._get_uint()
chunk.numUpvals = self.get_byte() chunk.numUpvals = self._get_byte()
chunk.numParams = self.get_byte() chunk.numParams = self._get_byte()
chunk.isVarg = (self.get_byte() != 0) chunk.isVarg = (self._get_byte() != 0)
chunk.maxStack = self.get_byte() chunk.maxStack = self._get_byte()
if (not chunk.name == ""):
chunk.name = chunk.name[1:-1]
# parse instructions # parse instructions
num = self.get_int() num = self._get_uint()
for i in range(num): for i in range(num):
chunk.appendInstruction(_decode_instr(self.get_int32())) chunk.appendInstruction(_decode_instr(self._get_uint32()))
# get constants # get constants
num = self.get_int() num = self._get_uint()
for i in range(num): for i in range(num):
constant: Constant = None constant: Constant = None
type = self.get_byte() type = self._get_byte()
if type == 0: #nil if type == 0: # nil
constant = Constant(ConstType.NIL, None) constant = Constant(ConstType.NIL, None)
elif type == 1: # bool elif type == 1: # bool
constant = Constant(ConstType.BOOL, (self.get_byte() != 0)) constant = Constant(ConstType.BOOL, (self._get_byte() != 0))
elif type == 3: # number elif type == 3: # number
constant = Constant(ConstType.NUMBER, self.get_double()) constant = Constant(ConstType.NUMBER, self._get_double())
elif type == 4: # string elif type == 4: # string
constant = Constant(ConstType.STRING, self.get_string(None)[:-1]) constant = Constant(ConstType.STRING, self._get_string())
else: else:
raise Exception("Unknown Datatype! [%d]" % type) raise Exception("Unknown Datatype! [%d]" % type)
chunk.appendConstant(constant) chunk.appendConstant(constant)
# parse protos # parse protos
num = self.get_int() num = self._get_uint()
for i in range(num): for i in range(num):
chunk.appendProto(self.decode_chunk()) chunk.appendProto(self.decode_chunk())
@@ -393,47 +385,47 @@ class LuaUndump:
# eh, for now just consume the bytes. # eh, for now just consume the bytes.
# line numbers # line numbers
num = self.get_int() num = self._get_uint()
for i in range(num): for i in range(num):
self.get_int() self._get_uint()
# locals # locals
num = self.get_int() num = self._get_uint()
for i in range(num): for i in range(num):
name = self.get_string(None)[:-1] # local name ([:-1] to remove the NULL terminator) name = self._get_string() # local name
start = self.get_int() # local start PC start = self._get_uint() # local start PC
end = self.get_int() # local end PC end = self._get_uint() # local end PC
chunk.appendLocal(Local(name, start, end)) chunk.appendLocal(Local(name, start, end))
# upvalues # upvalues
num = self.get_int() num = self._get_uint()
for i in range(num): for i in range(num):
self.get_string(None) # upvalue name chunk.appendUpval(self._get_string()) # upvalue name
return chunk return chunk
def decode_rawbytecode(self, rawbytecode): def decode_rawbytecode(self, rawbytecode):
# bytecode sanity checks # bytecode sanity checks
if not rawbytecode[0:4] == b'\x1bLua': if not rawbytecode[0:4] == _LUAMAGIC:
raise Exception("Lua Bytecode expected!") raise Exception("Lua Bytecode expected!")
bytecode = array.array('b', rawbytecode) bytecode = array.array('b', rawbytecode)
return self.decode_bytecode(bytecode) return self.decode_bytecode(bytecode)
def decode_bytecode(self, bytecode): def decode_bytecode(self, bytecode):
self.bytecode = bytecode self.bytecode = bytecode
# aligns index, skips header # aligns index, skips header
self.index = 4 self.index = 4
self.vm_version = self.get_byte() self.vm_version = self._get_byte()
self.bytecode_format = self.get_byte() self.bytecode_format = self._get_byte()
self.big_endian = (self.get_byte() == 0) self.big_endian = (self._get_byte() == 0)
self.int_size = self.get_byte() self.int_size = self._get_byte()
self.size_t = self.get_byte() self.size_t = self._get_byte()
self.instr_size = self.get_byte() # gets size of instructions self.instr_size = self._get_byte() # gets size of instructions
self.l_number_size = self.get_byte() # size of lua_Number self.l_number_size = self._get_byte() # size of lua_Number
self.integral_flag = self.get_byte() self.integral_flag = self._get_byte() # is lua_Number defined as an int? false = float/double, true = int/long/short/etc.
self.rootChunk = self.decode_chunk() self.rootChunk = self.decode_chunk()
return self.rootChunk return self.rootChunk
@@ -444,5 +436,122 @@ class LuaUndump:
return self.decode_rawbytecode(bytecode) return self.decode_rawbytecode(bytecode)
def print_dissassembly(self): def print_dissassembly(self):
LuaUndump.dis_chunk(self.rootChunk) self.rootChunk.print()
class LuaDump:
def __init__(self, rootChunk: Chunk):
self.rootChunk = rootChunk
self.bytecode = bytearray()
# header info
self.vm_version = 0x51
self.bytecode_format = 0x00
self.big_endian = False
# data sizes
self.int_size = 4
self.size_t = 8
self.instr_size = 4
self.l_number_size = 8
self.integral_flag = False # lua_Number is a double
def _writeBlock(self, data: bytes):
self.bytecode += bytearray(data)
def _set_byte(self, b: int):
self.bytecode.append(b)
def _set_uint32(self, i: int):
order = 'big' if self.big_endian else 'little'
self._writeBlock(i.to_bytes(4, order, signed=False))
def _set_uint(self, i: int):
order = 'big' if self.big_endian else 'little'
self._writeBlock(i.to_bytes(self.int_size, order, signed=False))
def _set_size_t(self, i: int):
order = 'big' if self.big_endian else 'little'
self._writeBlock(i.to_bytes(self.size_t, order, signed=False))
def _set_double(self, f: float):
order = '>d' if self.big_endian else '<d'
self._writeBlock(struct.pack(order, f))
def _set_string(self, string: str):
self._set_size_t(len(string)+1)
self._writeBlock(string.encode('utf-8'))
self._set_byte(0x00) # write null terminator
def _dumpChunk(self, chunk: Chunk):
# write meta info
self._set_string(chunk.name)
self._set_uint(chunk.frst_line)
self._set_uint(chunk.last_line)
self._set_byte(chunk.numUpvals)
self._set_byte(chunk.numParams)
self._set_byte(1 if chunk.isVarg else 1)
self._set_byte(chunk.maxStack)
# write instructions
self._set_uint(len(chunk.instructions))
for l in chunk.instructions:
self._set_uint32(_encode_instr(l))
# write constants
self._set_uint(len(chunk.constants))
for constant in chunk.constants:
# write constant data
if constant.type == ConstType.NIL:
self._set_byte(0)
elif constant.type == ConstType.BOOL:
self._set_byte(1)
self._set_byte(1 if constant.data else 0)
elif constant.type == ConstType.NUMBER: # number
self._set_byte(3)
self._set_double(constant.data)
elif constant.type == ConstType.STRING: # string
self._set_byte(4)
self._set_string(constant.data)
else:
raise Exception("Unknown Datatype! [%s]" % str(constant.type))
# write child protos
self._set_uint(len(chunk.protos))
for p in chunk.protos:
self._dumpChunk(p)
# write line numbers
self._set_uint(len(chunk.lineNums))
for l in chunk.lineNums:
self._set_uint(l)
# write locals
self._set_uint(len(chunk.locals))
for l in chunk.locals:
self._set_string(l.name)
self._set_uint(l.start)
self._set_uint(l.end)
# write upvals
self._set_uint(len(chunk.upvalues))
for u in chunk.upvalues:
self._set_string(u)
def _dumpHeader(self):
self._writeBlock(_LUAMAGIC)
# write header info
self._set_byte(self.vm_version)
self._set_byte(self.bytecode_format)
self._set_byte(0 if self.big_endian else 1)
self._set_byte(self.int_size)
self._set_byte(self.size_t)
self._set_byte(self.instr_size)
self._set_byte(self.l_number_size)
self._set_byte(self.integral_flag)
def dump(self) -> bytearray:
self._dumpHeader()
self._dumpChunk(self.rootChunk)
return self.bytecode

7
main.py Normal file → Executable file
View File

@@ -1,3 +1,4 @@
#!/usr/bin/env python3
import sys import sys
import lundump import lundump
import lparser import lparser
@@ -7,4 +8,8 @@ print(sys.argv[1])
chunk = lc.loadFile(sys.argv[1]) chunk = lc.loadFile(sys.argv[1])
lc.print_dissassembly() lc.print_dissassembly()
lp = lparser.LuaDecomp(chunk)
lp = lparser.LuaDecomp(chunk)
print("\n==== [[" + str(chunk.name) + "'s pseudo-code]] ====\n")
print(lp.getPseudoCode())