from capstone import *
from capstone.x86 import *

# Challenge binary.  The handle is left open on purpose: big_func() seeks into
# it again for every helper it analyses, and the script closes it at the end.
# File offsets in this binary are virtual address - 0x400000.
f = open("chal", "rb")
f.seek(0x40137c - 0x400000)

md = Cs(CS_ARCH_X86, CS_MODE_64)
md.detail = True  # operand details are required to inspect mov/call operands

# Code region that the main loop disassembles (starts at 0x40137c).
b = f.read(5374)

# Stack displacement (rbp-relative, minus 8) -> flag index.
# Keys 7 and 8 are the sentinel values the register tracker seeds into rbx
# and r13, so those registers resolve to flag indices 7 and 8.  The -664 slot
# is pinned by hand; pre-seeding it also stops the automatic numbering in the
# main loop from assigning it a different index.
inp_mapping = {
    7: 7,      # rbx
    8: 8,      # r13
    -664: 35,  # hand-assigned slot
}

# Next flag index handed out to a newly discovered stack slot.
n = 0

def big_func(addr, idx, *, image_base=0x400000, window=0x40000):
    """Translate one obfuscated helper function into an ``assign(...)`` line.

    Disassembles the helper at virtual address *addr* from the module-level
    file handle ``f`` (using the module-level disassembler ``md``) up to its
    first ``ret``, then prints one ``assign(flag, idx, ..., s)`` line.

    Two helper shapes are distinguished by the first byte of the function:

    * ``0x53`` (``push rbx``): emitted as a multiplication of ``flag[idx]``
      by (number of instructions whose operands mention ``r15``) - 1.
    * anything else: emitted as an addition to ``flag[idx]``; every
      ``r15``-mentioning instruction contributes +1 or -1 depending on
      bit 2 of the byte at index ``bytes[0] % len(bytes)`` of that
      instruction's own encoding.

    Args:
        addr: Virtual address of the helper (the ``call`` target).
        idx: Flag index the helper operates on.
        image_base: Virtual address corresponding to file offset 0.
        window: Maximum number of bytes read for the helper's body.

    Raises:
        ValueError: if *idx* is None, or *addr* maps past the end of the file.
        RuntimeError: if no ``ret`` is decoded within *window* bytes; without
            this the constraint would be silently missing from the output.
    """
    if idx is None:
        raise ValueError(f"no flag index known for helper at {hex(addr)}")

    f.seek(addr - image_base)
    b_func = f.read(window)
    if not b_func:
        raise ValueError(f"helper at {hex(addr)} lies beyond the end of the file")

    # 0x53 is the one-byte encoding of `push rbx`; this script uses it as the
    # marker for multiplicative helpers.
    mul = b_func[0] == 0x53
    n_cnt = 0
    for ins in md.disasm(b_func, addr):
        if "r15" in ins.op_str:
            if mul:
                n_cnt += 1
            else:
                # Presumably ported from the binary's own decoder (the first
                # version of this script used eax/esi/rdx as names): select a
                # byte of the instruction by index and test its bit 2.
                raw = ins.bytes
                picked = raw[raw[0] % len(raw)]
                n_cnt += 1 if picked & 4 else -1
        if ins.mnemonic == 'ret':
            if mul:
                # NOTE(review): the -1 discounts one r15 reference; confirm
                # which instruction it corresponds to in the helpers.
                print(f"assign(flag, {idx}, flag[{idx}] * {n_cnt-1}, s)")
            else:
                print(f"assign(flag, {idx}, flag[{idx}] + {n_cnt}, s)")
            return

    # Disassembly stopped (invalid bytes or end of window) before any `ret`.
    raise RuntimeError(
        f"no ret found for helper at {hex(addr)} within {window:#x} bytes"
    )

# Register name -> stack displacement (or sentinel) last loaded into it.
# rbx and r13 carry the sentinels 7 and 8, which inp_mapping resolves to
# flag indices 7 and 8.
# NOTE(review): only `mov` updates this map; other register writes (lea, xor,
# call clobbers, 32-bit aliases) leave entries stale - confirm this suffices.
regs = {'rbx': 7, 'r13': 8}

# Addresses inside the disassembled region.
CHECK_START = 0x40160a  # before this, rbp loads define flag slots; from here
                        # on, calls are translated into constraints
_BINARY_OPS = {
    0x403750: '-',  # helper emitted as flag[rdi] -= flag[rsi]
    0x403730: '+',  # helper emitted as flag[rdi] += flag[rsi]
}


def _flag_index(reg):
    """Return the flag index for the stack slot last loaded into *reg*.

    Raises KeyError with a descriptive message when *reg* is untracked or
    holds a displacement that was never assigned a flag index.
    """
    disp = regs.get(reg)
    if disp not in inp_mapping:
        raise KeyError(f"{reg} holds untracked value {disp!r}; no flag index")
    return inp_mapping[disp]


def _track_mov(ins):
    """Number new rbp-relative slots and propagate them through registers."""
    global n
    ops = ins.operands
    if len(ops) != 2:
        return
    dst, src = ops
    if dst.type != X86_OP_REG:
        return
    if src.type == X86_OP_MEM:
        if ins.reg_name(src.value.mem.base) != 'rbp':
            return
        disp = src.value.mem.disp - 8
        if ins.address < CHECK_START and disp not in inp_mapping:
            inp_mapping[disp] = n
            n += 1
            if n == 7:
                n = 9  # indices 7 and 8 are reserved for rbx/r13
        regs[ins.reg_name(dst.value.reg)] = disp
    elif src.type == X86_OP_REG:
        src_name = ins.reg_name(src.value.reg)
        if src_name in regs:
            regs[ins.reg_name(dst.value.reg)] = regs[src_name]


def _translate_call(ins):
    """Print the constraint(s) for one direct `call` in the checker."""
    ops = ins.operands
    if not ops or ops[0].type != X86_OP_IMM:
        return
    print(f"# {hex(ins.address)}")
    target = ops[0].value.imm
    op = _BINARY_OPS.get(target)
    if op is not None:
        dst_idx = _flag_index('rdi')
        src_idx = _flag_index('rsi')
        print(f"assign(flag, {dst_idx}, flag[{dst_idx}] {op} flag[{src_idx}], s)")
    else:
        big_func(target, _flag_index('rdi'))


try:
    # 0x40137c is the virtual address of the region read into `b`.
    for ins in md.disasm(b, 0x40137c):
        if ins.mnemonic == 'mov':
            _track_mov(ins)
        elif ins.mnemonic == 'call' and ins.address >= CHECK_START:
            _translate_call(ins)
finally:
    f.close()