#!/usr/bin/env python
"""
pax -- a python interpreter for ax (aka CSC258 assembler).
See http://wolever.net/~wolever/ax for the most recent source
See http://www.cs.utoronto.ca/~hehner/csc258/ for the original ax

Copyright 2007, David Wolever <david@wolever.net>
This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
version 2, as published by the Free Software Foundation.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.
"""

from sys import stdin, stdout, argv, exit
from os import system
from optparse import OptionParser
import re

# Instruction mnemonics in encoding order: a mnemonic's index in this list is
# the value stored in the top byte of an instruction word.
opcodes = ("LDA STA ADD SUB MUL DIV MOD "
           "FLA FLS FLM FLD CIF CFI AND "
           "IOR XOR BUN BZE BSA BIN INP OUT").split()

# Display names for character codes 0-127, indexed by code.  Control
# characters (0-31) and space (32) use their standard abbreviations, codes
# 33-126 are the printable characters themselves, and 127 is DEL.
ascii_codes = (("NUL SOH STX ETX EOT ENQ ACK BEL BS HT LF VT FF CR SO SI "
                "DLE DC1 DC2 DC3 DC4 NAK SYN ETB CAN EM SUB ESC FS GS RS US "
                "SP").split()
               + list(map(chr, range(33, 127)))
               + ["DEL"])

def chtostr(ch):
    """ Describes character code ch for trace output, e.g. "A (0x0041)" or
        "LF (0x000A)".
        Codes that have an entry in ascii_codes use that name.  Any other code
        is shown as the byte OUT actually writes (ch % 256), followed by the
        full value in hex. """
    if 0 <= ch < len(ascii_codes):
        return "%s (0x%04X)" %(ascii_codes[ch], ch)
    # On Python 2 byte strings "%c" raises OverflowError above 255.  The
    # accumulator can hold any 32-bit value, and OUT only emits the low byte.
    return "%c (0x%04X)" %(ch % 256, ch)

class ParseError(Exception):
    """ Raised for malformed program source.

        message -- description of the problem
        lineno  -- source line number, or None when unknown
        line    -- text of the offending line; shortened to 23 characters
                   plus "..." when longer than 25

        When lineno is set (and non-zero), str() reads
        "Error parsing line N (line): message".  Otherwise str() is just the
        message. """
    def __init__(self, message, lineno=None, line=""):
        # Populate .args so repr() and pickling see the message.
        Exception.__init__(self, message)
        self.lineno = lineno
        line = line or ""  # Tolerate line=None
        if len(line) > 25:
            line = line[:23] + "..."
        self.line = line
        self.m = message

    def __str__(self):
        if not self.lineno:
            return self.m or ""
        message = self.m
        if message:
            message = ": " + message
        return "Error parsing line %d (%s)%s" %(self.lineno, self.line, message or "")

class SegFault(Exception):
    """ Exception type for memory faults in the running ax program. """

class LabelError(Exception):
    """ Raised for label problems.  add_label raises it for a duplicate
        definition, and label_to_offset for an undefined label. """

def trace(str):
    """ Prints the message, but only when tracing (-t) or debugging (-d)
        is enabled in the global opts. """
    # The parameter name shadows the builtin str; it is kept as-is so
    # keyword callers keep working.
    if not (opts.tracing or opts.debug):
        return
    print(str)

def print_ram():
    """ Prints a listing of every word in RAM (see format_ram). """
    listing = format_ram(0, len(ram))
    print(listing)

def format_ram(start=0, end=-1):
    """ Returns a listing of RAM[start:end], one "OFFSET: VALUE SOURCE" line
        per word.

        end=-1 means "to the end of RAM".  It cannot simply default to
        len(ram): default values are evaluated when the function is defined,
        and ram only exists once parse() has run.

        A word that still matches its original source line is shown with that
        line.  A word that differs, or never had a source line, is shown with
        a guessed instruction prefixed with a bang (!). """
    if end == -1:
        end = len(ram)

    def describe(offset):
        # Prefer the real source line; fall back to a disassembly guess.
        if offset_has_changed(offset):
            return guess_source(ram[offset], offset)
        return offset_to_source(offset, guess=True)

    return "\n".join("%03X: %08X %s" % (offset, ram[offset], describe(offset))
                     for offset in range(start, end))

def offset_has_changed(o):
    """ Reports whether RAM offset o no longer holds what its source line
        assembled to.

        Returns True if no source line maps to o, or if the word in RAM
        differs from a fresh assembly of that line. """
    source = offset_to_source(o)
    if not source:
        return True  # No source line ever produced this word
    # Re-assemble the line.  The no-op handler ignores any label on it
    # instead of registering it again.
    assembled = fix_labels(parse_line(source, -1, lambda *args: None)[0])
    return ram[o] != assembled

def guess_source(instruction, offset=None):
    """ Returns a reconstructed source line for the instruction word,
        prefixed with a bang (!) to mark it as a guess.

        If offset is given and has a label, the label and a colon are
        right-aligned in the seven columns after the bang. """
    label = offset_to_label(offset) if offset is not None else ""
    if label:
        label = label + ":"
    return "!%7s%s" % (label, binary_to_instruction(instruction))

def add_label(label, offset, lineno):
    """ Records that `label` names RAM offset `offset` and was defined on
        source line `lineno`.  Raises LabelError if it is already defined. """
    if label in labels:
        raise LabelError("Duplicate definition of label %s" %(label))
    labels[label] = (offset, lineno)

def label_exists(l):
    """ Returns the (offset, line number) entry for label l, or None if l is
        not defined.  Callers use the result as a truth value. """
    return labels.get(l)

def offset_to_label(o):
    """ Returns a label that names RAM offset o, or "" if there is none.

        If several labels share an offset, the first one the labels dict
        yields wins.  This is a linear scan over all labels. """
    matches = (name for name, (offset, _lineno) in labels.items()
               if offset == o)
    return next(matches, "")

def offset_to_source(o, guess=False):
    """ Returns the source line that was assembled into RAM offset o.

        If no source line maps to o, returns a guessed reconstruction (see
        guess_source) when guess is True, and "" otherwise. """
    if o in unparsed_ram:
        return unparsed_ram[o]
    if guess:
        return guess_source(ram[o], o)
    return ""

def label_to_offset(l):
    """ Returns the RAM offset named by label l.
        Raises LabelError if the label is not defined. """
    try:
        return labels[l][0]
    except KeyError:
        # Only a missing key means "undefined label".  Other errors (e.g.
        # labels not yet initialised) should surface as they are.
        raise LabelError("Label %s is not defined." %l)

def parse_data_alloc(type, m):
    """ Converts a data definition into the list of words it places in RAM.

        type -- one-letter data type:
                'I' decimal integer (negatives stored in 32-bit two's complement)
                'F' floating point (not implemented; raises Exception)
                'C' characters packed one per byte, first character in the
                    lowest byte
                'B' binary integer
                'H' hexadecimal integer
                'W' reserve int(m) words, each initialised to 0
        m    -- the value text

        Returns [word] for a value type, or a list of zeros for 'W'.
        Raises ParseError if a value does not fit in an unsigned 32-bit word
        or a word count is negative.  Raises ValueError for malformed
        numbers or an unknown type. """
    if type == 'W':
        count = int(m)
        if count < 0:
            raise ParseError("Negative allocation size: %s" %m)
        return [0] * count

    if type == 'I':
        d = int(m)
        if d < 0: d += 0x100000000  # Fake 32-bit two's complement
    elif type == 'F':
        raise Exception("Sorry, floating point things have not been implemented yet")
    elif type == 'C':
        d = 0
        for i, ch in enumerate(m):
            d += ord(ch) << (i * 8)
    elif type == 'B':
        d = int(m, 2)
    elif type == 'H':
        d = int(m, 16)
    else:
        raise ValueError("Unknown data type: %s" %type)

    if d > 0xFFFFFFFF:
        raise ParseError("Data value too large: 0x%X" %d)
    if d < 0:
        # E.g. "H -ff", or an 'I' value below -2**32: not a 32-bit word.
        raise ParseError("Data value out of range: %s" %m)
    return [d]

def binary_to_instruction(b):
    """ Disassembles word b into "OPCODE operand".

        The opcode comes from the top byte and the operand from the low 24
        bits.  The operand is shown as the label at that address if there is
        one, otherwise as hex.  If the top byte is not a known opcode, the
        int b is returned unchanged. """
    opcode, address = b >> 24, b & 0x00FFFFFF
    if opcode >= len(opcodes):
        return b
    operand = offset_to_label(address) or "0x%08X" %(address)
    return "%s %s" %(opcodes[opcode], operand)

def instruction_to_binary(opcode, m):
    """ Encodes an instruction as a word.  The opcode's index in `opcodes`
        becomes the top byte, and the address the low 24 bits.

        opcode -- mnemonic, e.g. "LDA"
        m      -- a decimal address, or a label name

        Raises ParseError for an unknown opcode or an address wider than
        24 bits, and LabelError for an undefined label. """
    # An explicit check rather than assert, which is stripped under -O.
    if opcode not in opcodes:
        raise ParseError("Invalid opcode: %s" %opcode)
    if m.isdigit():
        address = int(m)
    else:
        address = label_to_offset(m)
    if address > 0x00FFFFFF:
        # A wider address would be OR-ed into the opcode byte and silently
        # turn this into a different instruction.
        raise ParseError("Address too large: 0x%X" %address)

    # See ax.c line 92
    return (opcodes.index(opcode) * 0x01000000) | address

# Grammar for one source line.  A line is one of:
#   [label:] T value  -- a one-letter data type T.  The value is either
#                        quoted ('...') or runs up to the next whitespace.
#   [label:] OP m     -- a mnemonic of two or more letters and its operand
#   #                 -- a comment-only line
# Any remaining text is captured as the comment group.  Matching ignores
# case.  tokenize_line uses search(), so the match need not start at
# column 0.
_line_regex = re.compile(r"""
(?: (?:(?P<label>[a-z0-9_]+):)?\s*               # The line's label
    (?:(?:(?P<data_type>[A-Z])                   # The data definition...
          (?:\s*(?P<qt>')|\s+)                   # Which is prefixed by a space or a quote
          (?P<data_value>.+?)                    # Then the data
          (?(qt)'|(?=\s|\Z)))                    # And it ends in a quote, space or line ending
    |(?P<op>[A-Z][A-Z]+)\s+(?P<m>[A-Za-z0-9_]+)) # Otherwise, the opcode, m
| \#)(?P<comment>.*)                             # and a comment, if it exists
""", re.IGNORECASE | re.VERBOSE)
def tokenize_line(line, lineno):
    """ Splits a source line into [label, data type, data value, opcode, m,
        comment].  The data type is a letter such as "W" or "C"; the data
        value is its argument.  Missing elements are "".

        Raises ParseError, tagged with lineno and line, if the line does not
        match the grammar, has neither a data definition, an instruction nor
        a comment, or has an unmatched quote in its data value. """
    try:
        matches = _line_regex.search(line)
        if not matches: raise ParseError("Incorrect format: error matching")
        groups = matches.groupdict()
        bits = [groups[name] or "" for name in
                ('label', 'data_type', 'data_value', 'op', 'm', 'comment')]
        data_type, data_value, op, m, comment = bits[1:]
        # A line must hold a complete data definition, a complete
        # instruction, or at least a comment.
        if not ((data_type and data_value) or (op and m) or comment):
            raise ParseError("Incorrect format: Incorrect elements")
        if data_value.count("'") % 2 != 0:
            raise ParseError("Incorrect format: Unmatched quote")
        return bits
    except ParseError as e:
        raise ParseError(str(e), lineno, line)

def test_tokenizer():
    """ Self-test for tokenize_line.  Each valid line must tokenize to exactly
        the listed fields, and each invalid line must raise ParseError.
        Fails with AssertionError or Exception otherwise. """
    # Each row: [line, label, data type, data value, operation, m, comment]
    valid_lines = [
            ["OP m", "", "", "", "OP", "m", ""],
            ["l:OP m", "l", "", "", "OP", "m", ""],
            ["\tl: OP m", "l", "", "", "OP", "m", ""],
            ["W 0 comment", "", "W", "0", "", "", " comment"],
            ["\tOP m xyz", "", "", "", "OP", "m", " xyz"],
            ["W' ' stuff", "", "W", " ", "", "", " stuff"],
            ["x:W  ' '", "x", "W", " ", "", "", ""],
            ["x: W' '", "x", "W", " ", "", "", ""],
            ["F 123 ", "", "F", "123", "", "", " "],
            ["\tx:  I 1234", "x", "I", "1234", "", "", ""],
            ["x:I'1234' asdf", "x", "I", "1234", "", "", " asdf"],
            ["x:F\t-1.2e3 asdf", "x", "F", "-1.2e3", "", "", " asdf"],
            ["zam:B'10101'thing", "zam", "B", "10101", "", "", "thing"],
            ["# comment", "", "", "", "", "", " comment"],
    ]

    # Lines that must be rejected
    invalid_lines = [
            "a",
            "line: asdf",
            "T'data",
            "F xd'"
    ]

    for row in valid_lines:
        source, expected = row[0], row[1:]
        tokens = tokenize_line(source, -1)
        assert tokens == expected, "'%s': %s != %s" %(source, tokens, expected)

    for source in invalid_lines:
        try:
            tokens = tokenize_line(source, -1)
        except ParseError:
            continue
        raise Exception("Invalid line parsed: %s %s" %(source, tokens))
# Run the tokenizer self-test once, at import time, so a broken _line_regex
# raises here before any program is parsed.
test_tokenizer()

def parse_line(line, lineno, label_handler = add_label):
    """ Assembles one source line into a list of RAM entries.

        Returns one of:
          - a list of ints for numeric data (a run of zeros for 'W');
          - ["OPCODE m"] for an instruction;
          - ["A label"] for an address constant.
        The string entries are resolved later by fix_labels, once every
        label is known.

        A label on the line is reported as
        label_handler(label, offset, lineno), where offset is the current
        length of RAM.

        Every failure, including label errors raised by label_handler, is
        re-raised as a single ParseError carrying lineno and the line text. """
    try:
        (label, data_type, data_value, op, m, comment) = tokenize_line(line, lineno)
        if label: label_handler(label, len(ram), lineno)
        if op in opcodes:
            if not m: raise ParseError("opcode without argument.")
            # Keep the instruction as text for now; labels are resolved
            # later by fix_labels.
            return ["%s %s" %(op, m)]
        elif data_type == 'A':
            return ["%s %s" %(data_type, data_value)]
        elif data_type in ('I', 'F', 'C', 'B', 'H', 'W'):
            try:
                return parse_data_alloc(data_type, data_value)
            except (TypeError, ValueError) as e:
                raise ParseError("bad data format (%s)" %(str(e)))
        else:
            raise ParseError("bad opcode or data prefix.")
    except ParseError as e:
        if e.lineno is not None:
            # Already tagged (by tokenize_line); wrapping it again would
            # repeat the "Error parsing line N" prefix.
            raise
        raise ParseError(str(e), lineno, line)
    except Exception as e:
        raise ParseError(str(e), lineno, line)

def parse(s):
    """ Assembles the ax program text s into RAM, discarding any earlier
        state.  Resets and fills these globals:
          ram          -- one entry per word; all ints once parse() returns
          unparsed_ram -- RAM offset -> the source line assembled at it
          labels       -- label name -> (offset, line number); 'opsys' is
                          predefined at 0xFACE
        Blank lines, and lines whose first non-blank character is '#', are
        skipped.  Raises ParseError for malformed lines and LabelError for
        undefined labels. """
    global ram
    global unparsed_ram
    global labels

    labels = {'opsys': (0xFACE, -1)}
    unparsed_ram = {}
    ram = []

    for lineno, line in enumerate(s.split("\n"), 1):
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue # Ignore blank lines and comments, even when indented

        unparsed_ram[len(ram)] = line
        ram.extend(parse_line(line, lineno))

    # All labels, including forward references, are now known, so every
    # opcode/label pair can be converted to its binary form.
    for c, word in enumerate(ram):
        ram[c] = fix_labels(word)

def fix_labels(line):
    """ Returns the binary word for one RAM entry produced by parse_line.

        Entries that are already ints are returned unchanged.  String entries
        are either "A label", which becomes the label's address, or
        "OPCODE m", which becomes the encoded instruction (m being an address
        or a label).  The name is a bit deceiving, because it translates
        opcodes as well as labels. """
    if not isinstance(line, str):
        return line
    # Split once only: a quoted A'...' operand may contain spaces.  It should
    # then be reported as an undefined label, not crash the unpacking.
    (op, m) = line.split(None, 1)
    if op == "A":
        return label_to_offset(m)
    return instruction_to_binary(op, m)

def run():
    """ Executes the assembled program in RAM.

        Execution starts at the 'main' label and ends when the program
        counter reaches the 'opsys' address (0xFACE).

        Registers:
          AC -- accumulator, truncated to 32 bits after every instruction
          E  -- flag set by the arithmetic and logical instructions and
                tested by BZE
          PC -- program counter
        INP and OUT read and write single characters on stdin and stdout.

        Raises LabelError if 'main' is undefined, SegFault if the program
        counter leaves RAM without reaching opsys, and Exception for
        instructions that are not implemented. """
    pc = label_to_offset("main")
    ac = 0
    e = 0
    tracing = opts.tracing or opts.debug

    while pc != 0xFACE:
        if not 0 <= pc < len(ram):
            raise SegFault("Program counter out of range: 0x%X" %pc)
        m = ram[pc] & 0x00FFFFFF
        op = opcodes[ram[pc] >> 24]

        if tracing:
            # Build the trace line only when it will be printed: format_ram
            # re-parses the instruction's source line on every step.
            if m >= len(ram): # e.g. opsys (0xFACE) lies outside RAM
                ram_addr = "m: 0x%08X         " %m
            else:
                ram_addr = "ram[0x%03X]: 0x%08X" %(m, ram[m])
            trace("AC: 0x%08X %s E: %d %s " %(ac, ram_addr, e, format_ram(pc, pc+1)))
        if opts.debug and "BREAK" in offset_to_source(pc):
            print("BREAK: Press any key to continue...")
            stdin.read(1)

        if op == "LDA": ac = ram[m]
        elif op == "STA": ram[m] = ac
        elif op == "ADD":
            e = 0 # This is a bug-for-bug duplication of ax.c...
            ac += ram[m]
        elif op == "SUB": e = 0; ac -= ram[m]
        elif op == "MUL": e = 0; ac *= ram[m]
        elif op == "DIV":
            if ram[m] == 0: e = 1
            else: e = 0; ac //= ram[m]
        elif op == "MOD":
            if ram[m] == 0: e = 1
            else: e = 0; ac %= ram[m]
        elif op == "AND": ac &= ram[m]; e = int(ac != 0)
        elif op == "IOR": ac |= ram[m]; e = int(ac != 0)
        elif op == "XOR": ac ^= ram[m]; e = int(ac != 0)
        elif op == "BUN": pc = m - 1 # Because pc will be increased at the end
        elif op == "BZE":
            if not e: pc = m - 1
        elif op == "BSA": ram[m] = pc + 1; pc = m
        elif op == "BIN": pc = (ram[m] & 0x00FFFFFF) - 1
        elif op == "INP":
            # NOTE(review): at end of input read(1) returns "" and ord()
            # raises TypeError; confirm what ax.c does on EOF.
            ac = ord(stdin.read(1))
            if tracing: trace("Read: %s" %chtostr(ac))
        elif op == "OUT":
            stdout.write(chr(ac % 256))
            if tracing: trace("\rWrote: %s" %chtostr(ac))
        else: raise Exception("Sorry, the instruction %s has not been implemented yet" %op)
        ac = ac & 0xFFFFFFFF # Make sure the AC doesn't get too big
        pc += 1

def pretty_print(file_name):
    """ Prints a uniformly formatted copy of the ax program in file_name.

        Formatting rules:
          - labels (with their colon) are padded to eight columns;
          - comment-only lines start with '#';
          - blank lines are kept.

        As a safety check, both the original and the reformatted text are
        assembled.  If the resulting RAM images differ, an apology is printed
        and the process exits with status 1 instead of printing the result.
        Tokenizing or parsing errors in the original file propagate as
        ParseError. """
    out = []
    source_lines = []
    with open(file_name) as source:
        for lineno, raw_line in enumerate(source, 1):
            line = raw_line.rstrip("\r\n")
            source_lines.append(line)
            if not line.strip():
                out.append("")  # Blank lines have no tokens; keep them as-is
                continue

            (label, data_type, data_value, op, m, comment) = tokenize_line(line, lineno)
            if label: label += ":"

            if data_type and data_value:
                if data_type == "C": data = "C'%s'" %(data_value)
                else:                data = "%s %s" %(data_type, data_value)
            else:
                data = None

            if not (data_type or op): fmt = "#%(comment)s"
            elif data_type and data_value == "0" and not label: fmt = "%(data)s %(comment)s"
            elif data: fmt = "%(label)-8s%(data)s%(comment)s"
            else: fmt = "%(label)-8s%(op)s %(m)s%(comment)s"
            out.append(fmt %{'label': label, 'data': data, 'op': op,
                             'm': m, 'comment': comment})

    # Compare the pretty-printed code with the original, and refuse to
    # print it if they do not assemble to identical RAM.
    try:
        parse("\n".join(out))
        prettified_ram = ram[:]
    except Exception:
        prettified_ram = None  # Reported below as a mismatch
    parse("\n".join(source_lines))

    if ram != prettified_ram:
        print("Oh no! I broke! The parsed version of the prettifed file doesn't match the original file!")
        print("Please email the code that broke it to david@wolever.net")
        exit(1)
    print("\n".join(out))

def parse_args():
    """ Parses the command line.

        Stores the options in the global `opts`, with boolean attributes
        tracing, debug and prettify, and returns the single input file name.
        Exits via parser.error() unless exactly one file is given. """
    global opts
    parser = OptionParser(usage = "usage: %prog [options] FILE\nAny line with the word BREAK on it is considered a breakpoint.\nUse -h for help")
    flags = (("t", "tracing", "Enable tracing"),
             ("d", "debug", "Enable debugger breakpoints (implies -t)"),
             ("p", "prettify", "Pretty-print the code (but do not run it)"))
    for short_name, long_name, help_text in flags:
        parser.add_option("-" + short_name, "--" + long_name,
                          action="store_true", dest=long_name, help=help_text)
    (opts, file_names) = parser.parse_args()

    if not file_names:
        parser.error("no input file specified")
    if len(file_names) > 1:
        parser.error("too many files specified")
    return file_names[0]

def echo(status):
    """ Configures the terminal through stty.

        status true  -- turn echo back on and leave cbreak mode.
        status false -- enter cbreak mode with echo off, so single
                        keypresses are read without being echoed. """
    system("stty -cbreak echo" if status else "stty cbreak -echo")

if __name__ == "__main__":
    try:
        echo(False)
        file_name = parse_args()

        if opts.prettify:
            pretty_print(file_name)
            exit(0)

        with open(file_name) as source:
            parse(source.read())
        # Trace RAM up to the '_data' label (presumably the start of the data
        # section), or all of it if there is no such label.  A conditional
        # expression is used because `x and off or -1` falls back to -1 when
        # _data is at offset 0.
        data_start = label_to_offset("_data") if label_exists("_data") else -1
        trace(format_ram(0, data_start))
        run()
    finally:
        echo(True)
