#!/usr/bin/env python

from socket import *
from select import *
from datetime import datetime, timedelta
import atexit

# How long a freshly accepted connection may stay silent before
# new_ssh_connection() hands it to the SSH passthrough (half a second).
HTTP_TIMEOUT = timedelta(milliseconds=500)

# (bit value, name) pairs used by poll_status() to render poll() event
# masks.  The values are taken from the select module rather than written
# out by hand: they differ between platforms (e.g. POLLWRNORM is 4 -- the
# same bit as POLLOUT -- on some systems but 256 on Linux), so a
# hard-coded table mislabels flags.  Names the running platform does not
# define are left out.  Kept in alphabetical order by name.
import select as _select_module

_POLL_FLAG_NAMES = ('POLLERR', 'POLLHUP', 'POLLIN', 'POLLNVAL', 'POLLOUT',
                    'POLLPRI', 'POLLRDBAND', 'POLLRDNORM', 'POLLWRBAND',
                    'POLLWRNORM')

# A generator expression (not a list comprehension) so that no loop
# variable leaks into module scope under Python 2.
_poll_stati = list((getattr(_select_module, _flag_name), _flag_name)
                   for _flag_name in _POLL_FLAG_NAMES
                   if hasattr(_select_module, _flag_name))

def poll_status(status):
    """ Render a poll() event mask as a human-readable string.

        Returns 'None' for a falsy status (None when poll() timed out, or
        0).  Otherwise returns a comma-separated list of the names from
        _poll_stati whose bit is set in status.  Any set bits that no table
        entry accounts for are appended as a hex value, so an unrecognised
        flag is never silently dropped (and a non-zero status never renders
        as an empty string). """
    if not status:
        return 'None'
    names = []
    known = 0
    for (flag, name) in _poll_stati:
        if flag & status:
            names.append(name)
            known |= flag
    unknown = status & ~known
    if unknown:
        names.append('0x%x' % unknown)
    return ", ".join(names)

def ss(con):
    """ Socket String -- the label used for a socket in log output.

        Despite the name, this returns the socket's integer file
        descriptor (con.fileno()); callers interpolate it with %s. """
    descriptor = con.fileno()
    return descriptor

# Every registered socket, keyed by its file descriptor (filled by reg()).
fd_to_con = dict()
# Newly accepted sockets whose protocol is not yet known -> deadline
# (a datetime) after which they are promoted to SSH.
new_con_timeouts = dict()
# Each end of a passthrough -> the socket at the other end; every pair is
# stored in both directions.
established_cons = dict()

def reg(con, poll_flags=POLLIN):
    """ Register con with the poll object and remember it by descriptor.

        Every socket the proxy handles is passed through here.  poll_flags
        is the event mask given to _p.register(); it defaults to POLLIN. """
    # fd_to_con is only mutated here, never rebound, so no `global` needed.
    _p.register(con, poll_flags)
    descriptor = con.fileno()
    fd_to_con[descriptor] = con
    print("Registered connection %s" % (ss(con),))

def close(con):
    """ Shut down one connection and forget about it.

        Drops con from fd_to_con and from the poll set, discards any
        pending HTTP_TIMEOUT deadline for it, and tears down its
        passthrough partner via teardown_passthrough() if it has one.
        Finally closes the socket itself. """
    print("Closing con %s" % (ss(con),))
    fd_to_con.pop(con.fileno())
    _p.unregister(con)
    # Only present if con was closed before its protocol was decided.
    new_con_timeouts.pop(con, None)
    if con in established_cons:
        teardown_passthrough(con)
    con.close()

# File descriptors of connections closed by teardown_passthrough() while
# the poll() generator is still handing out the events of one batch (that
# is, after the call to _p.poll() returned but before the generator has
# finished iterating over the sockets with something happening on them).
# The generator resets this list at the start of every batch and skips any
# descriptor found in it, so stale events for closed sockets are ignored.
_closed = list()
def teardown_passthrough(con):
    """ Dismantle the passthrough that con belongs to.

        Forgets the con <-> partner pairing in both directions.  Records
        the partner's descriptor in _closed so the poll() generator ignores
        any event already queued for it, then closes the partner with
        close(). """
    partner = established_cons.pop(con)
    del established_cons[partner]
    _closed.append(partner.fileno())
    close(partner)

# The real select.poll object; created before the wrapper below shadows
# the name `poll`.
_p = poll()
def poll(timeout):
    """ A generator wrapping _p.poll().

        timeout is passed straight to _p.poll() (milliseconds; -1 blocks).
        Yields (None, None) once if the call returned no events (timed out).
        Otherwise yields one (con, status) pair per ready descriptor, with
        the descriptor translated back to its socket through fd_to_con.
        Descriptors closed by teardown_passthrough() while an earlier event
        of the same batch was being handled are skipped. """
    global _closed
    # Start a fresh batch: teardown_passthrough() appends to this list
    # while the caller is still consuming events yielded below.
    _closed = []
    polled = _p.poll(timeout) or [(None, None)]
    for (fd, status) in polled:
        print("Poll'd %s, %s" % (fd, poll_status(status)))
        # NOTE(review): the original author skipped descriptor 256 with the
        # remark "I don't know where this is coming from...".  Kept as-is.
        if fd == 256:
            continue
        if fd is None:
            # poll() timed out.  Compare against None explicitly: 0 is a
            # valid (but falsy) descriptor and must still be looked up.
            yield (None, None)
            continue
        # Already closed while handling an earlier event of this batch.
        if fd in _closed:
            continue
        yield (fd_to_con[fd], status)

def new_connection(con, addr):
    """ Handle a freshly accepted socket whose protocol is not yet known.

        con is registered for POLLIN and given a deadline of
        now + HTTP_TIMEOUT in new_con_timeouts.  addr is the peer address
        returned by accept(); only addr[0] is used, for logging. """
    host = addr[0]
    print("Got connection from %s" % (host,))
    reg(con, POLLIN)
    deadline = datetime.now() + HTTP_TIMEOUT
    new_con_timeouts[con] = deadline

def new_ssh_connection(port=22):
    """ Promote every new connection whose deadline has passed to SSH.

        Each connection in new_con_timeouts whose deadline is not in the
        future is passed through to localhost:port (port 22 by default).
        Connections whose deadline has not yet arrived are left untouched,
        so calling this when nothing has expired does nothing. """
    now = datetime.now()
    # Iterate over a snapshot: new_passthrough() deletes entries from
    # new_con_timeouts, and mutating a dict while iterating a live view
    # of it raises RuntimeError on Python 3.
    for (con, deadline) in list(new_con_timeouts.items()):
        if deadline > now:
            continue
        print("New SSH connection: %s" % (ss(con),))
        new_passthrough(con, port)

def new_https_connection(con, port=80):
    """ Promote con to an HTTPS passthrough.

        Called when a new connection has pending data before its deadline.
        con is handed to new_passthrough() towards localhost:port; port
        defaults to 80.

        NOTE(review): the name says HTTPS but the default target is port
        80 -- confirm which port the local web server actually serves. """
    print("New HTTPS connection: %s" % (ss(con),))
    new_passthrough(con, port)

def new_passthrough(con, port):
    """ Pair con with a fresh upstream connection to localhost:port.

        con is removed from new_con_timeouts because its protocol is now
        decided.  A new TCP socket is connected to ('localhost', port) and
        registered for POLLIN.  The two sockets are then recorded as each
        other's partner in established_cons.

        If the upstream connect fails (e.g. nothing listens on port), the
        upstream socket is closed and con is closed too.  The exception is
        not propagated: it would otherwise kill the proxy's main loop and
        leave con registered with poll but belonging to no passthrough.

        NOTE(review): connect() is blocking, so the whole event loop waits
        on it; acceptable for localhost, but worth knowing. """
    print("Passing connection %s through to port %s" % (ss(con), port))
    del new_con_timeouts[con]
    upstream = socket(AF_INET, SOCK_STREAM)
    try:
        upstream.connect(('localhost', port))
    except EnvironmentError as e:
        # socket.error is an EnvironmentError subclass on Python 2.6+ and
        # an alias of OSError (== EnvironmentError) on Python 3.
        print("Could not connect %s to port %s: %s" % (ss(con), port, e))
        upstream.close()
        close(con)
        return
    reg(upstream)
    established_cons[upstream] = con
    established_cons[con] = upstream

def close_all_sockets():
    """ atexit hook: close every socket still listed in fd_to_con.

        Only the sockets are closed; the bookkeeping dicts and the poll
        registrations are left as they are. """
    for sock in fd_to_con.values():
        label = ss(sock)
        print("Closing con %s" % (label,))
        sock.close()
atexit.register(close_all_sockets)

def _poll_timeout():
    """ Return the timeout, in milliseconds, for the next poll() call.

        -1 (block until an event arrives) when no new connection is
        waiting to be classified.  Otherwise, the time left until the
        earliest deadline in new_con_timeouts, rounded up and never
        negative, so expired connections are promoted promptly. """
    if not new_con_timeouts:
        return -1
    remaining = min(new_con_timeouts.values()) - datetime.now()
    # +1 ms so we do not wake a hair before the deadline and spin.
    return max(0, int(remaining.total_seconds() * 1000) + 1)


def _forward(con, state):
    """ Relay one chunk (up to 1024 bytes) from con to its partner.

        Closes con on EOF, on a socket error while reading or writing, or
        when poll reported a state without POLLIN (e.g. POLLERR/POLLNVAL
        alone).  Closing goes through close(), which also tears down the
        partner.  None of these conditions crash the proxy any more. """
    # Internal invariant: every other registered socket is the listener
    # or a not-yet-classified connection, both handled by the caller.
    assert con in established_cons, \
           "Tried to pass data through a connection " \
           "(%s) that has not been established." %(ss(con))
    if not state & POLLIN:
        print("Connection %s reported %s without readable data; closing."
              % (ss(con), poll_status(state)))
        close(con)
        return
    print("Reading from %s..." % (ss(con),))
    try:
        data = con.recv(1024)
    except EnvironmentError as e:
        print("Error reading from %s: %s" % (ss(con), e))
        close(con)
        return
    if not data:
        # EOF.  Do not insist on POLLHUP: POSIX does not require it for a
        # read-side EOF (Linux, for one, reports a peer's shutdown as
        # POLLIN/POLLRDHUP), and asserting it crashed the whole proxy.
        close(con)
        return
    partner = established_cons[con]
    print("%r... read.  Writing to %s" % (data[:10], ss(partner)))
    try:
        # send() may write only part of the buffer; sendall() loops until
        # every byte has been written, so no data is silently lost.
        partner.sendall(data)
    except EnvironmentError as e:
        print("Error writing to %s: %s" % (ss(partner), e))
        close(con)
        return
    print("Done!")


s = socket(AF_INET, SOCK_STREAM)
s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
s.bind(('', 4001))
s.listen(5)
# Readability is the event that signals a pending connection for accept();
# the listener does not need POLLPRI/POLLOUT.
reg(s, POLLIN)

while True:
    for (con, state) in poll(_poll_timeout()):
        if con is None:
            # poll() timed out; expired connections are promoted below.
            continue
        if con is s:
            # Accept this connection
            new_connection(*s.accept())
        elif con in new_con_timeouts:
            # A new connection got data before its deadline -- HTTPS.
            new_https_connection(con)
        else:
            # We've got data to pass!
            _forward(con, state)
    # Promote connections whose deadline has passed.  Doing this after
    # every batch (not only when poll() times out) means steady traffic on
    # other sockets can no longer postpone SSH promotion indefinitely.
    new_ssh_connection()
