#!/usr/bin/python

import getopt
import signal
import sys
import os
import re
import time

if sys.platform =='win32':
  import msvcrt
else:
  import tty
  import termios

###############
# DESCRIPTION #
###############

# A very simple instruction-set simulator for the simpleCPUv1d.
# The program is read from a plain-text .dat file containing one
# "<decimal address> <16-bit binary word>" pair per line, and can be
# single-stepped or free-run.

###################
# INSTRUCTION-SET #
###################

# INSTR   IR15 IR14 IR13 IR12 IR11 IR10 IR09 IR08 IR07 IR06 IR05 IR04 IR03 IR02 IR01 IR00  
# MOVE    0    0    0    0    RD   RD   X    X    K    K    K    K    K    K    K    K
# ADD     0    0    0    1    RD   RD   X    X    K    K    K    K    K    K    K    K
# SUB     0    0    1    0    RD   RD   X    X    K    K    K    K    K    K    K    K
# AND     0    0    1    1    RD   RD   X    X    K    K    K    K    K    K    K    K

# LOAD    0    1    0    0    A    A    A    A    A    A    A    A    A    A    A    A
# STORE   0    1    0    1    A    A    A    A    A    A    A    A    A    A    A    A
# ADDM    0    1    1    0    A    A    A    A    A    A    A    A    A    A    A    A
# SUBM    0    1    1    1    A    A    A    A    A    A    A    A    A    A    A    A

# JUMPU   1    0    0    0    A    A    A    A    A    A    A    A    A    A    A    A
# JUMPZ   1    0    0    1    A    A    A    A    A    A    A    A    A    A    A    A
# JUMPNZ  1    0    1    0    A    A    A    A    A    A    A    A    A    A    A    A
# JUMPC   1    0    1    1    A    A    A    A    A    A    A    A    A    A    A    A 

# CALL    1    1    0    0    A    A    A    A    A    A    A    A    A    A    A    A

# OR      1    1    0    1    RD   RD   X    X    K    K    K    K    K    K    K    K  -- Version 1.2
# XOP1    1    1    1    0    U    U    U    U    U    U    U    U    U    U    U    U  -- NOT IMPLEMENTED

# RET     1    1    1    1    X    X    X    X    X    X    X    X    0    0    0    0
# MOVE    1    1    1    1    RD   RD   RS   RS   X    X    X    X    0    0    0    1
# LOAD    1    1    1    1    RD   RD   RS   RS   X    X    X    X    0    0    1    0  -- REG INDIRECT
# STORE   1    1    1    1    RD   RD   RS   RS   X    X    X    X    0    0    1    1  -- REG INDIRECT   
# ROL     1    1    1    1    RSD  RSD  X    X    X    X    X    X    0    1    0    0  -- Version 1.1

# ROR     1    1    1    1    RSD  RSD  X    X    X    X    X    X    0    1    0    1  -- NOT IMPLEMENTED
# ADD     1    1    1    1    RD   RD   RS   RS   X    X    X    X    0    1    1    0  -- NOT IMPLEMENTED
# SUB     1    1    1    1    RD   RD   RS   RS   X    X    X    X    0    1    1    1  -- NOT IMPLEMENTED
# AND     1    1    1    1    RD   RD   RS   RS   X    X    X    X    1    0    0    0  -- NOT IMPLEMENTED
# OR      1    1    1    1    RD   RD   RS   RS   X    X    X    X    1    0    0    1  -- NOT IMPLEMENTED
# XOR     1    1    1    1    RD   RD   RS   RS   X    X    X    X    1    0    1    0  -- Version 1.1
# ASL     1    1    1    1    RD   RD   RS   RS   X    X    X    X    1    0    1    1  -- Version 1.2

# XOP2    1    1    1    1    RD   RD   RS   RS   X    X    X    X    1    1    0    0  -- NOT IMPLEMENTED REG INDIRECT
# XOP3    1    1    1    1    RD   RD   RS   RS   X    X    X    X    1    1    0    1  -- NOT IMPLEMENTED
# XOP4    1    1    1    1    RD   RD   RS   RS   X    X    X    X    1    1    1    0  -- NOT IMPLEMENTED REG INDIRECT
# XOP5    1    1    1    1    RD   RD   RS   RS   X    X    X    X    1    1    1    1  -- NOT IMPLEMENTED

# .data  IMM

#############
# VARIABLES #
#############

# True while the simulator is free-running (set by the 'r' key). Cleared by
# Ctrl+C (signal_handler), by reaching a breakpoint, or when the PC stops
# changing, which drops back into single-step mode.
run = False

#############
# FUNCTIONS #
#############

def get_key():
  """Wait for a single key press and return it as a one-character string.

  On Windows this uses msvcrt.getwch(), which returns text directly, so
  special keys (which the console reports with a 0x00/0xE0 prefix) do not
  trigger a UTF-8 decode error. Elsewhere the terminal on stdin is put
  into raw mode for a single read, and its previous settings are restored
  afterwards even if the read fails.
  """
  if sys.platform == 'win32':
    return msvcrt.getwch()
  else:
    fd = sys.stdin.fileno()
    old_settings = termios.tcgetattr(fd)
    try:
      tty.setraw(fd)
      ch = sys.stdin.read(1)
    finally:
      termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
    return ch

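# pad() left-fills an address with zeros so it prints as four digits:
#   1000 - 4095 -> unchanged
#   0100 - 0999 -> one leading zero
#   0010 - 0099 -> two leading zeros
#   0000 - 0009 -> three leading zeros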

def pad(number):
  """Return ``number`` as text with leading zeros for display.

  Values below 10 get "000", below 100 "00", below 1000 "0"; anything
  else is returned as plain ``str(number)``.
  """
  text = str(number)
  for limit, zeros in ((10, "000"), (100, "00"), (1000, "0")):
    if number < limit:
      return zeros + text
  return text

def convertData(data):
  """Convert a user-supplied number string to an int.

  Accepts binary with a ``0b`` prefix, hexadecimal with a ``0x`` prefix
  (either prefix in any letter case, optionally after a leading sign) or
  plain decimal (leading zeros allowed, e.g. "010" -> 10). Surrounding
  whitespace is ignored.

  On any conversion failure an error and the offending input are printed
  and the program exits with status 1.
  """
  try:
    text = data.strip()
    digits = text.lstrip('+-').lower()
    # Pick the base from the prefix only: a substring test would treat hex
    # such as "0x0b" (which contains "0b") as binary and reject it.
    if digits.startswith('0b'):
      return int(text, 2)
    elif digits.startswith('0x'):
      return int(text, 16)
    else:
      return int(text)
  except (AttributeError, TypeError, ValueError):
    print("Error: invalid data can not convert")
    print(data)
    sys.exit(1)

def getReg( reg, ra, rb, rc, rd ):
  """Return the value of the register selected by a 2-bit field.

  ``reg`` is the two-character binary string taken from an instruction:
  "00" -> ra, "01" -> rb, "10" -> rc, "11" -> rd. Any other selector
  prints an error naming it and exits with status 1.
  """
  values = {"00": ra, "01": rb, "10": rc, "11": rd}
  try:
    return values[reg]
  except (KeyError, TypeError):
    # Report the selector that was actually passed in.
    print("Error: invalid reg: " + str(reg))
    sys.exit(1)

def signal_handler(sig, frame):
  """SIGINT handler: announce Ctrl+C and leave free-run mode.

  Clearing the module-level ``run`` flag makes the main loop fall back to
  single-step mode after the current instruction.
  """
  global run
  notice = '\nYou pressed Ctrl+C'
  print(notice)
  run = False

################
# MAIN PROGRAM #
################

def _to_word(value):
  """Return ``value`` as a 16-character binary string.

  Returns None when ``value`` does not fit in an unsigned 16-bit word
  (0..65535), so callers can report the error in their own words.
  """
  if 0 <= value <= 65535:
    return bin(value)[2:].zfill(16)
  return None

def _dump_memory(memory):
  """Print every 16-word row of ``memory`` that is not all zeros, in hex."""
  for row in range(0, len(memory), 16):
    block = memory[row:row + 16]
    if all(val == "0000000000000000" for val in block):
      continue
    # Row label is the address of the first word; 3 hex digits cover 0..fff.
    line = f"{row:03x}:"
    for val in block:
      line += f" {int(val, 2):04x}"
    print(line)

def simple_cpu_v1d_simulator(argv):
  """Load a simpleCPUv1d machine-code file and simulate it interactively.

  ``argv`` is the full command line (program name first), parsed with getopt:
    -i/--input <file>       machine-code file; ".dat" is appended when the
                            name does not already contain it
    -d/--debug <0/1/2>      debug verbosity (any value other than 0 or 1
                            selects level 2)
    -b/--breakpoint <addr>  stop free-running when the PC reaches <addr>;
                            may be repeated (decimal, 0x.. or 0b..)

  Each non-blank line of the input file is "<decimal address> <16-bit
  binary word>". Execution starts at address 0 in single-step mode. The
  prompt keys are r=run, s=step, v=registers, x=read, w=write, q=quit.
  When an instruction leaves the PC unchanged (e.g. a jump to itself), free
  running stops. The next key press then prints the registers and the
  non-zero memory rows and returns.

  Exits with status 1 on bad arguments, a missing, unreadable or malformed
  input file, or an unimplemented/invalid instruction.
  """
  global run

  signal.signal( signal.SIGINT, signal_handler )

  if len(argv) <= 1:
    print ("Usage: simple_cpu_v1d_simulator.py -i <input_file.dat>")
    print ("                                   -b <address>")
    print ("                                   -d <0/1/2>")
    return

  s_config = 'i:d:b:'
  # getopt long options that take a value must end in '='.
  l_config = ['input=', 'debug=', 'breakpoint=']

  source = []
  source_filename = ""

  # 4096 words, each held as a 16-character binary string.
  memory = ["0000000000000000"] * 4096

  debug = False
  debug_level = 0

  input_file_present = False

  # Breakpoint addresses, converted to ints once while parsing arguments.
  breakpoint_address = []

  jump = False
  jump_address = 0

  ra = 0
  rb = 0
  rc = 0
  rd = 0

  pc = 0
  z = 0
  c = 0

  prev_pc = 0

  # 4-entry circular return-address stack used by CALL/RET.
  stack = [0,0,0,0]
  stack_pointer = 0

  # Opcode -> mnemonic. Bits 15-12 select the instruction; in the 1111
  # group bits 3-0 select it as well ("XXXX" marks the unused low part).
  isa = {"0000XXXX":"move",
         "0001XXXX":"add",
         "0010XXXX":"sub",
         "0011XXXX":"and",
         "0100XXXX":"load",
         "0101XXXX":"store",
         "0110XXXX":"addm",
         "0111XXXX":"subm",
         "1000XXXX":"jumpu",
         "1001XXXX":"jumpz",
         "1010XXXX":"jumpnz",
         "1011XXXX":"jumpc",
         "1100XXXX":"call",
         "1101XXXX":"or",
         "1110XXXX":"xop1",
         "11110000":"ret",
         "11110001":"move",
         "11110010":"load",
         "11110011":"store",
         "11110100":"rol",
         "11110101":"ror",
         "11110110":"add",
         "11110111":"sub",
         "11111000":"and",
         "11111001":"or",
         "11111010":"xor",
         "11111011":"asl",
         "11111100":"xop2",
         "11111101":"xop3",
         "11111110":"xop4",
         "11111111":"xop5" }

  # 2-bit register field -> register name.
  reg = {"00":"ra",
         "01":"rb",
         "10":"rc",
         "11":"rd" }

  try:
    options, remainder = getopt.getopt(argv[1:], s_config, l_config)
  except getopt.GetoptError as m:
    print("Error: invalid arguments -", m)
    sys.exit(1)

  # process arguments #
  # ----------------- #

  for opt, arg in options:
    if opt in ('-i', '--input'):
      if ".dat" in arg:
        source_filename = arg
      else:
        source_filename = arg + ".dat"

      if os.path.isfile(source_filename):
        input_file_present = True
    elif opt in ('-d', '--debug'):
      level = convertData(arg)
      if level == 0:
        debug = False
      elif level == 1:
        debug = True
        debug_level = 1
      else:
        debug = True
        debug_level = 2
    elif opt in ('-b', '--breakpoint'):
      breakpoint_address.append(convertData(arg))

  if debug:
    print ("read input parameter : OK")
    if debug_level == 2:
      print( str(source_filename) + " " + str(debug) + "\n")

  # read source file #
  # ---------------- #

  if input_file_present:
    try:
      with open(source_filename, "r") as source_file:
        source = source_file.readlines()
    except IOError:
      print("Error: could not open source file")
      sys.exit(1)
  else:
    print("Error: could not find source file")
    sys.exit(1)

  if debug:
    print ("read code : OK")
    if debug_level == 2:
      print ( source )
      print("")

  # read machine code "<decimal address> <16-bit binary word>" #
  # ---------------------------------------------------------- #

  number_of_lines = 0
  for line in source:
    # split() with no argument also drops CR/LF and leading whitespace,
    # so whitespace-only lines are skipped rather than crashing int().
    words = line.lower().split()
    if not words:
      continue
    try:
      address = int(words[0])
    except ValueError:
      address = -1
    if not 0 <= address < len(memory):
      print("Error: invalid address: " + str(words[0]))
      sys.exit(1)
    if len(words) < 2:
      print("Error: missing data for address: " + str(words[0]))
      sys.exit(1)
    memory[address] = words[1]
    number_of_lines = number_of_lines + 1

  if debug:
    print ("code processed : OK")
    print ("instructions : " + str(number_of_lines))
    if debug_level == 2:
      for i in range(number_of_lines):
        print( memory[i] )

  # run program #
  # ----------- #

  print("WARNING : this simulator is not guaranteed to be functionally accurate when compared to the HW")
  print("Press CTRL-C to stop simulation run")

  while True:
    instruction = memory[pc]

    regx = instruction[4:6]   # bits 11-10: RD / RSD
    regy = instruction[6:8]   # bits 9-8:   RS

    opcode_high = instruction[0:4]
    if opcode_high == "1111":
      opcode_low  = instruction[12:]
    else:
      opcode_low  = "XXXX"

    opcode = opcode_high + opcode_low

    # 8-bit immediate K (bits 7-0): sign-extended to 16 bits, and zero-extended.
    signed = int(instruction[8] * 8 + instruction[8:], 2)
    unsigned = int(instruction[8:], 2)
    # 12-bit address field A (bits 11-0).
    absolute = int(instruction[4:], 2)

    disassembled_instruction = ""
    jump = False

    if debug:
      print( str(instruction) + " " + str(opcode_high) )

    # MOVE #
    # ---- #

    if opcode_high == "0000":
      if regx == "00":
        ra = signed
      elif regx == "01":
        rb = signed
      elif regx == "10":
        rc = signed
      elif regx == "11":
        rd = signed
      else:
        print("Error: move - invalid regx: " + str(regx))
        sys.exit(1)
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " " + reg[regx] + " " + str(signed) + " -> " + reg[regx] + ":" + str(signed)

    # ADD #
    # --- #

    elif opcode_high == "0001":
      if regx == "00":
        tmp = ra + signed
        ra = tmp & 65535
      elif regx == "01":
        tmp = rb + signed
        rb = tmp & 65535
      elif regx == "10":
        tmp = rc + signed
        rc = tmp & 65535
      elif regx == "11":
        tmp = rd + signed
        rd = tmp & 65535
      else:
        print("Error: add - invalid regx: " + str(regx))
        sys.exit(1)
      # z reflects the 16-bit result: a sum of exactly 65536 wraps to 0.
      z = int((tmp & 65535) == 0)
      c = int(tmp > 65535)
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " " + reg[regx] + " " + str(signed) + " -> " + reg[regx] + ":" + str(tmp & 65535) + " z:" + str(z) + " c:" + str(c)

    # SUB #
    # --- #

    elif opcode_high == "0010":
      if regx == "00":
        tmp = ra - signed
        ra = tmp & 65535
      elif regx == "01":
        tmp = rb - signed
        rb = tmp & 65535
      elif regx == "10":
        tmp = rc - signed
        rc = tmp & 65535
      elif regx == "11":
        tmp = rd - signed
        rd = tmp & 65535
      else:
        print("Error: sub - invalid regx: " + str(regx))
        sys.exit(1)
      z = int(tmp == 0)
      # NOTE(review): operands are both 0..65535, so tmp never exceeds 65535
      # and c is always 0 after SUB; confirm against the HW carry/borrow.
      c = int(tmp > 65535)
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " " + reg[regx] + " " + str(signed) + " -> " + reg[regx] + ":" + str(tmp & 65535) + " z:" + str(z) + " c:" + str(c)

    # AND #
    # --- #

    elif opcode_high == "0011":
      if regx == "00":
        tmp = ra & unsigned
        ra = tmp
      elif regx == "01":
        tmp = rb & unsigned
        rb = tmp
      elif regx == "10":
        tmp = rc & unsigned
        rc = tmp
      elif regx == "11":
        tmp = rd & unsigned
        rd = tmp
      else:
        print("Error: and - invalid regx: " + str(regx))
        sys.exit(1)
      z = int(tmp == 0)
      c = 0
      # Show the zero-extended mask that was actually applied.
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " " + reg[regx] + " " + str(unsigned) + " -> " + reg[regx] + ":" + str(tmp) + " z:" + str(z) + " c:" + str(c)

    # LOAD #
    # ---- #

    elif opcode_high == "0100":
      if absolute <= 4095:
        ra = int( memory[absolute], 2 )
      else:
        print("Error: load - invalid abs: " + str(absolute))
        sys.exit(1)

      if debug:
        if debug_level == 2:
          print( memory[absolute] )
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " ra " + str(absolute) + " -> ra:" + str(ra)

    # STORE #
    # ----- #

    elif opcode_high == "0101":
      if absolute <= 4095:
        word = _to_word(ra)
        if word is not None:
          memory[absolute] = word
        else:
          print("Error: store - invalid ra: " + str(ra))
          sys.exit(1)
      else:
        print("Error: store - invalid abs: " + str(absolute))
        sys.exit(1)

      if debug:
        if debug_level == 2:
          print( memory[absolute] )
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " ra " + str(absolute)

    # ADDM #
    # ---- #

    elif opcode_high == "0110":
      if absolute <= 4095:
        tmp = ra + int( memory[absolute], 2 )
        ra = tmp & 65535
      else:
        print("Error: addm - invalid abs: " + str(absolute))
        sys.exit(1)
      # z reflects the 16-bit result, as for ADD.
      z = int((tmp & 65535) == 0)
      c = int(tmp > 65535)
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " ra " + str(absolute) + " -> ra:" + str(ra) + " z:" + str(z) + " c:" + str(c)

    # SUBM #
    # ---- #

    elif opcode_high == "0111":
      if absolute <= 4095:
        tmp = ra - int( memory[absolute], 2 )
        ra = tmp & 65535
      else:
        print("Error: subm - invalid abs: " + str(absolute))
        sys.exit(1)
      z = int(tmp == 0)
      # NOTE(review): as for SUB, c is always 0 here; confirm against the HW.
      c = int(tmp > 65535)
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " ra " + str(absolute) + " -> ra:" + str(ra) + " z:" + str(z) + " c:" + str(c)

    # JUMPU #
    # ----- #

    elif opcode_high == "1000":
      if absolute <= 4095:
        jump = True
        jump_address = absolute
      else:
        print("Error: jumpu invalid abs: " + str(absolute))
        sys.exit(1)
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " " + str(absolute) + " -> " + str(jump)

    # JUMPZ #
    # ----- #

    elif opcode_high == "1001":
      if absolute <= 4095:
        if z == 1:
          jump = True
          jump_address = absolute
      else:
        print("Error: jumpz invalid abs: " + str(absolute))
        sys.exit(1)
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " " + str(absolute) + " -> " + str(jump) + " z:" + str(z) + " c:" + str(c)

    # JUMPNZ #
    # ------ #

    elif opcode_high == "1010":
      if absolute <= 4095:
        if z == 0:
          jump = True
          jump_address = absolute
      else:
        print("Error: jumpnz invalid abs: " + str(absolute))
        sys.exit(1)
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " " + str(absolute) + " -> " + str(jump) + " z:" + str(z) + " c:" + str(c)

    # JUMPC #
    # ----- #

    elif opcode_high == "1011":
      if absolute <= 4095:
        if c == 1:
          jump = True
          jump_address = absolute
      else:
        print("Error: jumpc invalid abs: " + str(absolute))
        sys.exit(1)
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " " + str(absolute) + " -> " + str(jump) + " z:" + str(z) + " c:" + str(c)

    # CALL #
    # ---- #

    elif opcode_high == "1100":
      if absolute <= 4095:
        jump = True
        # Return address wraps within the 4096-word address space.
        stack[stack_pointer] = (pc + 1) & 4095
        jump_address = absolute
        stack_pointer = stack_pointer + 1
        if stack_pointer > 3:
          stack_pointer = 0
      else:
        print("Error: call invalid abs: " + str(absolute))
        sys.exit(1)
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " " + str(absolute) + " -> sp:" + str(stack_pointer)

    # RET #
    # --- #

    elif opcode_high == "1111" and opcode_low == "0000":
      jump = True
      if stack_pointer == 0:
        stack_pointer = 3
      else:
        stack_pointer = stack_pointer - 1

      jump_address = stack[stack_pointer]
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " -> sp:" + str(stack_pointer) + " " + str(jump_address)

    # MOVE #
    # ---- #

    elif opcode_high == "1111" and opcode_low == "0001":
      tmp = getReg( regy, ra, rb, rc, rd )
      if regx == "00":
        ra = tmp
      elif regx == "01":
        rb = tmp
      elif regx == "10":
        rc = tmp
      elif regx == "11":
        rd = tmp
      else:
        print("Error: move - invalid regx: " + str(regx))
        sys.exit(1)
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " " + reg[regx] + " " + reg[regy] + " -> " + reg[regx] + ":" + str(tmp)

    # LOAD #
    # ---- #

    elif opcode_high == "1111" and opcode_low == "0010":
      addr = getReg( regy, ra, rb, rc, rd ) & 4095
      tmp  = int( memory[addr], 2 )
      if regx == "00":
        ra = tmp
      elif regx == "01":
        rb = tmp
      elif regx == "10":
        rc = tmp
      elif regx == "11":
        rd = tmp
      else:
        print("Error: load - invalid regx: " + str(regx))
        sys.exit(1)
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " " + reg[regx] + " (" + reg[regy] + ") -> " + reg[regx] + ":" + str(tmp)

    # STORE #
    # ----- #

    elif opcode_high == "1111" and opcode_low == "0011":
      addr = getReg( regy, ra, rb, rc, rd ) & 4095
      data = getReg( regx, ra, rb, rc, rd )
      word = _to_word(data)
      if word is not None:
        memory[addr] = word
      else:
        print("Error: store - invalid " + reg[regx] + ": " + str(data))
        sys.exit(1)
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " " + reg[regx] + " (" + reg[regy] + ")"

    # ROL #
    # --- #

    elif opcode_high == "1111" and opcode_low == "0100":
      # ROL rotates RSD (bits 11-10, regx) left by one bit; bits 9-8 are
      # don't-care in this encoding, so they must not select the source.
      bits = bin(getReg( regx, ra, rb, rc, rd ))[2:].zfill(16)
      tmp = int((bits[1:] + bits[:1]), 2)
      if regx == "00":
        ra = tmp
      elif regx == "01":
        rb = tmp
      elif regx == "10":
        rc = tmp
      elif regx == "11":
        rd = tmp
      else:
        print("Error: rol - invalid regx: " + str(regx))
        sys.exit(1)
      z = int(tmp == 0)
      c = 0
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " " + reg[regx] + " -> " + reg[regx] + ":" + str(tmp)

    # XOR #
    # --- #

    elif opcode_high == "1111" and opcode_low == "1010":
      if regx == "00":
        tmp = ra ^ getReg( regy, ra, rb, rc, rd )
        ra = tmp
      elif regx == "01":
        tmp = rb ^ getReg( regy, ra, rb, rc, rd )
        rb = tmp
      elif regx == "10":
        tmp = rc ^ getReg( regy, ra, rb, rc, rd )
        rc = tmp
      elif regx == "11":
        tmp = rd ^ getReg( regy, ra, rb, rc, rd )
        rd = tmp
      else:
        print("Error: xor - invalid regx: " + str(regx))
        sys.exit(1)
      z = int(tmp == 0)
      c = 0
      disassembled_instruction = str(pad(pc)) + ": " + isa[opcode] + " " + reg[regx] + " " + reg[regy] + " -> " + reg[regx] + ":" + str(tmp)

    else:
      # Includes opcodes listed in `isa` that this simulator does not model.
      print("Error: invalid opcode: " + str(opcode_high) + " " + str(opcode_low) )
      sys.exit(1)

    # update line counter #
    # ------------------- #

    print( disassembled_instruction )

    prev_pc = pc

    if jump:
      pc = jump_address
    else:
      # Addresses are 12 bits (4096 words): wrap instead of indexing past
      # the end of memory.
      pc = (pc + 1) & 4095

    if not run:
      print("    : r=run, s=step, v=registers, x=read, w=write, q=quit")
      while True:
        key = get_key()
        # An unchanged PC means the program is parked on itself: any key
        # (or 'q') prints the final state and ends the simulation.
        if key == 'q' or prev_pc == pc:
          print( "\nra:" + str(ra) +  " rb:" + str(rb) + " rc:" + str(rc) + " rd:" + str(rd) + " sp:" + str(stack_pointer) )
          _dump_memory(memory)
          return
        elif key == 'r':
          run = True
          break
        elif key == 's':
          break
        elif key == 'v':
          print( "    : ra:" + str(ra) +  " rb:" + str(rb) + " rc:" + str(rc) + " rd:" + str(rd) + " sp:" + str(stack_pointer) + " z:" + str(z) + " c:" + str(c) )
        elif key == 'x':
          address = input("    : Enter address - ")
          if address:
            addr = convertData( address )
            if addr >=0 and addr < 4096:
              print( "   : " + str(memory[addr]) + " - " + str( int(memory[addr], 2) ) )
            else:
              print( "   : Invalid address" )
        elif key == 'w':
          address = input("    : Enter address - ")
          data = input("    : Enter data - ")

          if address:
            addr = convertData( address )
            if addr >=0 and addr < 4096:
              if data:
                word = _to_word( convertData( data ) )
                if word is not None:
                  memory[addr] = word
                else:
                  print( "   : Invalid data" )
            else:
              print( "   : Invalid address" )

    else:
      time.sleep(0.25)
      if pc in breakpoint_address:
        run = False
      if prev_pc == pc:
        run = False
   
if __name__ == '__main__':
  # Script entry point: pass the full command line (program name included).
  simple_cpu_v1d_simulator(sys.argv)


