import argparse
import binascii
import collections
import os
import re
import shutil
import sys

# Parsed command-line options: one field per option declared in get_args().
Args = collections.namedtuple(
    "args",
    [
        "mode", "input_file", "output_file", "log_folder",
        "b_payload_data", "b_header_digest", "b_aes_key", "b_signature", "b_header_id",
        "less_data",
    ],
)


# Non-fatal problems collected while processing images; main() writes them to stderr.
ERROR_MESSAGE = ''


class ImgType:
    """Image type codes held in the low three bits of an image's flags.

    These are the values of the 3-bit ``type`` field of SBFlags; callers
    extract them with ``int(flags, 16) & 0x7``.
    """
    payload_image = 0    # 'Payload image'
    aes_key_image = 1    # 'AES key image'
    root_cert = 2        # 'Root certificate'
    not_root_cert = 3    # 'Non-root certificate'
    payload_return = 4   # 'Payload return image'
    payload_no_exec = 5  # 'Payload no exec'


class SBflags:
    """Bit positions of single-bit fields inside the SBFlags bitfield.

    Offsets follow the C layout reproduced in this module: type occupies
    bits 0-2, checksum bit 3, encrypted bit 4, hash_of_encrypted bit 5 and
    is_signed bit 6. Only the two bits this script reads are named here.
    """
    encrypted = 4  # payload is rounded up to 16 bytes when set
    is_signed = 6  # a signature follows the header when set


'''

typedef struct
{
    SBIMAGE_TYPE type : 3;
    uint32_t checksum : 1;
    uint32_t encrypted : 1;
    uint32_t hash_of_encrypted : 1;
    uint32_t is_signed : 1;
    uint32_t :25;
} SBFlags;

typedef struct
{
    uint32_t header_id;
    uint32_t size;
    uint32_t load_addr;
    uint32_t entry_addr;
    SBFlags flags;
    uint32_t key_number;
    uint32_t cert_id;
    uint32_t sign_cert_id;
    uint32_t payload_digest[8];
    uint32_t header_digest[8];
} SBimage;

'''


class SBimage:
    """One record of a secure-boot image: header fields, signature and payload.

    Every field starts out as an empty string and is later filled with hex
    text (by read_sbimg or create_sbimg). Item access is an alias for
    attribute access, so ``block['size'] = '...'`` sets ``block.size``; this
    lets create_sbimg assign fields by the names it parses from a text dump.
    """

    def __init__(self):
        # Same order as the on-disk layout, followed by signature and payload.
        for field_name in ('header_id', 'size', 'load_addr', 'entry_addr',
                           'flags', 'key_number', 'cert_id', 'sign_cert_id',
                           'payload_digest', 'header_digest', 'signature',
                           'data'):
            setattr(self, field_name, '')

    def __setitem__(self, item, data):
        # block[name] = value  ->  block.name = value
        setattr(self, item, data)

    def __getitem__(self, item):
        # block[name]  ->  block.name
        return getattr(self, item)


def get_args():
    """Parse and validate the command line.

    Validation happens before any file-system side effect, so a bad command
    line never wipes the log folder.

    Side effect: when --log_folder is given, that folder is deleted (if it
    exists) and recreated empty.

    :return: Args namedtuple with one field per command-line option.
    :raises OSError: if the input file does not exist.
    Exits via ``parser.error`` (status 2) when --output_file is missing in
    'break' or 'create' mode, both of which write an output image.
    """
    # RawTextHelpFormatter keeps the explicit '\n's in the --mode help text.
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument(
        '--mode', choices=['break', 'create', 'dump'], required=True,
        help='select mode: \n break sbimage and create new one; \n create sbimage from text file; \n dump sbimage;'
    )

    parser.add_argument('-i', '--input_file', required=True, help='path to input file')
    parser.add_argument('-o', '--output_file', help='path to output file')
    parser.add_argument('--log_folder', help='path to log folder')

    parser.add_argument('--b_payload_data', action='store_true', help='break payload data')
    parser.add_argument('--b_header_digest', action='store_true', help='break header digest')
    parser.add_argument('--b_aes_key', action='store_true', help='break aes key')
    parser.add_argument('--b_signature', action='store_true', help='break signature')
    parser.add_argument('--b_header_id', action='store_true', help='break header_id')

    parser.add_argument('--less_data', action='store_true', help='show less data in dump mode')

    raw_args = parser.parse_args()

    if not os.path.isfile(raw_args.input_file):
        raise OSError('Input file not found "{}"'.format(raw_args.input_file))

    # 'break' and 'create' both end in write_sbimg(..., output_file).
    if raw_args.mode in ('break', 'create') and raw_args.output_file is None:
        parser.error('--output_file is required in {} mode'.format(raw_args.mode))

    if raw_args.log_folder is not None:
        # Start from an empty folder so logs from an earlier run do not linger.
        if os.path.exists(raw_args.log_folder):
            shutil.rmtree(raw_args.log_folder)
        os.makedirs(raw_args.log_folder)

    return Args(**vars(raw_args))


def reverse_string(input_string):
    """Byte-swap every 32-bit word of a hex string.

    The input is cut into 8-character (4-byte) words and the two-character
    byte pairs inside each word are emitted in reverse order, e.g.
    'aabbccdd11223344' -> 'ddccbbaa44332211'. A shorter trailing word is
    reversed the same way. Applying the function twice restores any
    even-length input.
    """
    swapped_words = []
    for start in range(0, len(input_string), 8):
        chunk = input_string[start:start + 8]
        # Walking backwards two at a time: first/second digit of each byte.
        high_digits = chunk[-2::-2]
        low_digits = chunk[::-2]
        swapped_words.append(''.join(hi + lo for hi, lo in zip(high_digits, low_digits)))
    return ''.join(swapped_words)


def file_chunk_to_str(data):
    """Convert raw bytes read from an image into the script's hex text form.

    The bytes are hex-encoded and every 32-bit word is byte-swapped by
    reverse_string, e.g. b'GMBS' -> '53424d47'.

    :param data: raw bytes (``str`` on Python 2, ``bytes`` on Python 3).
    :return: lower-case hex ``str``.
    """
    hex_data = binascii.hexlify(data)
    # hexlify returns str on Python 2 but bytes on Python 3; reverse_string
    # joins characters, so it needs text. Decode only when necessary so the
    # Python 2 result type is unchanged.
    if not isinstance(hex_data, str):
        hex_data = hex_data.decode('ascii')
    return reverse_string(hex_data)


def align_size(size, alignment):
    """Round ``size`` up to the next multiple of ``alignment``.

    Only integer arithmetic is used, so the result is an ``int`` on both
    Python 2 and 3 and can be passed straight to ``file.read()``. With ``/``,
    Python 3 would produce a float, which ``read()`` rejects.

    :param size: byte count to align.
    :param alignment: block size to round up to (must be non-zero).
    :return: the smallest multiple of ``alignment`` that is >= ``size``.
    """
    remainder = size % alignment
    if remainder:
        size += alignment - remainder
    return size


def read_sbimg(image_file):
    """Parse a binary secure-boot image into a list of SBimage blocks.

    The file is consumed as back-to-back records, each made of:

    * a 96-byte header whose fields follow the SBimage C struct layout,
    * a 384-byte signature (RSA 3072), present only when the ``is_signed``
      flag bit is set,
    * the payload, whose length is the header ``size`` rounded up to 16
      bytes when the ``encrypted`` flag bit is set and to 4 bytes otherwise.

    Every field is stored as hex text via file_chunk_to_str. Non-fatal
    problems are appended to the module-level ERROR_MESSAGE instead of being
    raised. These are: an unexpected header id, or a truncated header,
    signature or payload.

    :param image_file: path to the binary image.
    :return: list of SBimage objects in file order.
    """
    global ERROR_MESSAGE

    sbheader_size = 96
    signature_size = 384  # RSA 3072
    sbimage = list()

    with open(image_file, 'rb') as f:
        while True:
            sbheader = f.read(sbheader_size)
            # Truthiness works for Python 2 str and Python 3 bytes alike;
            # comparing with '' never matches bytes and would never stop.
            if not sbheader:
                break
            if len(sbheader) < sbheader_size:
                # Too short to slice into fields: report it and stop instead
                # of parsing garbage (int('', 16) would raise further down).
                ERROR_MESSAGE += 'Truncated SB header: {} of {} bytes\n'.format(
                    len(sbheader), sbheader_size)
                break

            image_block = SBimage()
            image_block.header_id = file_chunk_to_str(sbheader[0:4])
            if image_block.header_id != '53424d47':
                ERROR_MESSAGE += '{} is not SB header\n'.format(image_block.header_id)
            image_block.size = file_chunk_to_str(sbheader[4:8])
            image_block.load_addr = file_chunk_to_str(sbheader[8:12])
            image_block.entry_addr = file_chunk_to_str(sbheader[12:16])
            image_block.flags = file_chunk_to_str(sbheader[16:20])
            image_block.key_number = file_chunk_to_str(sbheader[20:24])
            image_block.cert_id = file_chunk_to_str(sbheader[24:28])
            image_block.sign_cert_id = file_chunk_to_str(sbheader[28:32])
            image_block.payload_digest = file_chunk_to_str(sbheader[32:64])
            image_block.header_digest = file_chunk_to_str(sbheader[64:96])

            if get_bit_from_flags(image_block.flags, SBflags.is_signed):
                signature = f.read(signature_size)
                if len(signature) < signature_size:
                    ERROR_MESSAGE += 'Truncated signature: {} of {} bytes\n'.format(
                        len(signature), signature_size)
                image_block.signature = file_chunk_to_str(signature)

            # Payload length is rounded up to 16 bytes when the encrypted bit
            # is set, to 4 bytes otherwise.
            if get_bit_from_flags(image_block.flags, SBflags.encrypted):
                alignment = 16
            else:
                alignment = 4

            data_size = align_size(int(image_block.size, 16), alignment)
            data = f.read(data_size)
            if len(data) < data_size:
                ERROR_MESSAGE += 'Truncated payload: {} of {} bytes\n'.format(
                    len(data), data_size)

            image_block.data = file_chunk_to_str(data)
            sbimage.append(image_block)

    return sbimage


def get_image_type(flags):
    """Return a human-readable name for the image type encoded in ``flags``.

    ``flags`` is a hex string; the type is its low three bits (an ImgType
    value). Type codes without a name map to the empty string.
    """
    type_names = {
        ImgType.payload_image: 'Payload image',
        ImgType.payload_return: 'Payload return image',
        ImgType.payload_no_exec: 'Payload no exec',
        ImgType.aes_key_image: 'AES key image',
        ImgType.root_cert: 'Root certificate',
        ImgType.not_root_cert: 'Non-root certificate',
    }
    return type_names.get(int(flags, 16) & 0x7, '')


def print_sbimg(sbimage, text_file=None, less_data=False):
    """Render SBimage blocks as text and write them to a file or to stdout.

    Each block becomes an 'Image type: ...' line followed by 'key = value'
    lines, the format create_sbimg reads back. With ``less_data`` the dump is
    for viewing only: long payloads are elided, so it can no longer be turned
    back into an image.

    :param sbimage: iterable of SBimage blocks.
    :param text_file: output path; when None the text goes to stdout.
    :param less_data: abbreviate payloads longer than 32 hex characters to
        their first and last 4 characters.
    """
    parts = []
    for image_block in sbimage:
        parts.append((
            'Image type: {}\n'
            'header_id = {}\n'
            'size = {}\n'
            'load_addr = {}\n'
            'entry_addr = {}\n'
            'flags = {}\n'
            'key_number = {}\n'
            'cert_id = {}\n'
            'sign_cert_id = {}\n'
            'payload_digest = {}\n'
            'header_digest = {}\n'
        ).format(
            get_image_type(image_block.flags),
            image_block.header_id,
            image_block.size,
            image_block.load_addr,
            image_block.entry_addr,
            image_block.flags,
            image_block.key_number,
            image_block.cert_id,
            image_block.sign_cert_id,
            image_block.payload_digest,
            image_block.header_digest
        ))

        if image_block.signature != '':
            parts.append('signature = {}\n\n'.format(image_block.signature))

        if less_data and len(image_block.data) > 32:
            parts.append('data = {}...{}\n\n'.format(image_block.data[:4], image_block.data[-4:]))
        else:
            parts.append('data = {}\n\n'.format(image_block.data))

    str_buff = ''.join(parts)

    if text_file is not None:
        # Binary mode keeps '\n' line endings on every platform. On Python 3
        # a binary file only accepts bytes, so encode there; on Python 2 the
        # text already is a byte string and is written unchanged.
        if not isinstance(str_buff, bytes):
            str_buff = str_buff.encode('utf-8')
        with open(text_file, 'wb') as f:
            f.write(str_buff)
    else:
        # Same output as a Python 2 ``print`` statement (text plus newline),
        # but valid on Python 3 as well.
        sys.stdout.write(str_buff + '\n')


def create_sbimg(input_file):
    """Build SBimage blocks from a text dump in the print_sbimg format.

    A line starting with 'Image type' begins a new block, but only once the
    current block has a header_id. Every 'key = value' line sets attribute
    ``key`` on the current block. Other lines, such as blank lines and the
    'Image type: ...' line itself, set nothing. The last block is always
    appended.

    :param input_file: path to the text dump.
    :return: list of SBimage objects in file order.
    """
    sbimage = list()
    image_block = SBimage()
    # Text mode normalizes '\r\n' line endings and yields str on Python 3.
    with open(input_file, 'r') as text:
        for line in text:
            if line.startswith('Image type'):
                if image_block.header_id != '':
                    sbimage.append(image_block)
                    image_block = SBimage()
            key, sep, value = line.partition(' = ')
            if sep:
                # strip() drops the line ending and stray whitespace; a '\r'
                # or space left in a value would make bytearray.fromhex fail
                # when the image is written.
                image_block[key.strip()] = value.strip()
    sbimage.append(image_block)
    return sbimage


def break_data(data):
    data_list = list(data)
    data_list[0] = format(int(data_list[0], 16) ^ 1, 'x')
    data = ''.join(data_list)
    return data


def get_bit_from_flags(flags, bit_number):
    int_flags = int(flags, 16)
    return (int_flags >> bit_number) & 1


def break_sbimg(args, sbimage):
    """Corrupt the fields selected by the ``b_*`` options in each block.

    - AES key images: payload data (``b_aes_key``).
    - Payload images (ImgType.payload_image only): payload data
      (``b_payload_data``), header digest (``b_header_digest``) and
      signature (``b_signature``).
    - Every block: header id (``b_header_id``).

    Each corruption is done by break_data. Blocks are modified in place and
    the same list is returned. Asking to break the signature of an unsigned
    payload image is reported in ERROR_MESSAGE rather than crashing.

    :param args: Args namedtuple from get_args.
    :param sbimage: list of SBimage blocks.
    :return: ``sbimage``.
    """
    global ERROR_MESSAGE

    for image_block in sbimage:
        img_type = int(image_block.flags, 16) & 0x7
        if img_type == ImgType.aes_key_image:
            if args.b_aes_key:
                image_block.data = break_data(image_block.data)
        elif img_type == ImgType.payload_image:
            if args.b_payload_data:
                image_block.data = break_data(image_block.data)
            if args.b_header_digest:
                image_block.header_digest = break_data(image_block.header_digest)
            if args.b_signature:
                if image_block.signature:
                    image_block.signature = break_data(image_block.signature)
                else:
                    # read_sbimg leaves signature empty when is_signed is clear.
                    ERROR_MESSAGE += 'Cannot break signature: payload image is not signed\n'
        if args.b_header_id:
            image_block.header_id = break_data(image_block.header_id)

    return sbimage


def write_sbimg(sbimage, out_file):
    """Serialize SBimage blocks back into a binary image file.

    For every block the header fields, the signature and the payload are
    emitted in on-disk order. reverse_string undoes the per-word byte swap
    applied when the image was read. The combined hex text is converted to
    bytes before the output file is opened, so invalid hex leaves no partial
    file behind.
    """
    field_order = (
        'header_id', 'size', 'load_addr', 'entry_addr', 'flags',
        'key_number', 'cert_id', 'sign_cert_id', 'payload_digest',
        'header_digest', 'signature', 'data',
    )
    hex_text = ''.join(
        reverse_string(getattr(block, field_name))
        for block in sbimage
        for field_name in field_order
    )
    raw_bytes = bytearray.fromhex(hex_text)

    with open(out_file, 'wb') as out:
        out.write(raw_bytes)


def main():
    """Run the mode selected on the command line.

    Modes:

    - 'break': read the input image, corrupt fields per the ``b_*`` options
      and write the output image. Before/after text dumps go to the log
      folder when one is given.
    - 'create': build an image from a text dump. A text rendering goes to the
      log folder when one is given.
    - 'dump': write a text rendering of the input image to the output file,
      or to stdout when no output file is given.

    Problems collected in ERROR_MESSAGE are written to stderr at the end;
    they do not change the return value.
    """
    args = get_args()

    def log_path(image_path, suffix):
        # <log_folder>/<image base name without extension><suffix>
        return os.path.join(
            args.log_folder, os.path.splitext(os.path.basename(image_path))[0] + suffix)

    if args.mode == 'break':
        sbimage = read_sbimg(args.input_file)
        if args.log_folder is not None:
            print_sbimg(sbimage, log_path(args.input_file, '_in.txt'))

        sbimage = break_sbimg(args, sbimage)

        if args.log_folder is not None:
            print_sbimg(sbimage, log_path(args.output_file, '_out.txt'))
        write_sbimg(sbimage, args.output_file)
    elif args.mode == 'create':
        sbimage = create_sbimg(args.input_file)
        # --log_folder is optional; os.path.join(None, ...) would raise.
        if args.log_folder is not None:
            print_sbimg(sbimage, log_path(args.output_file, '.txt'))
        write_sbimg(sbimage, args.output_file)
    elif args.mode == 'dump':
        sbimage = read_sbimg(args.input_file)
        print_sbimg(sbimage, args.output_file, args.less_data)

    if ERROR_MESSAGE:
        # Equivalent to ``print >> sys.stderr, ERROR_MESSAGE`` on Python 2,
        # and valid syntax on Python 3.
        sys.stderr.write(ERROR_MESSAGE + '\n')


if __name__ == '__main__':
    # sys.exit is always available, unlike the site-provided exit() helper,
    # which is missing under ``python -S`` and in some embedded interpreters.
    sys.exit(main())

