#!/usr/bin/python
#
# Copyright (c) 2012, 2013, NORDUnet A/S
# All rights reserved.
#
#   Redistribution and use in source and binary forms, with or
#   without modification, are permitted provided that the following
#   conditions are met:
#
#     1. Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#     2. Redistributions in binary form must reproduce the above
#        copyright notice, this list of conditions and the following
#        disclaimer in the documentation and/or other materials provided
#        with the distribution.
#     3. Neither the name of the NORDUnet nor the names of its
#        contributors may be used to endorse or promote products derived
#        from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Author : Fredrik Thulin <fredrik@thulin.net>
#
"""
MITM (man in the middle) proxy of pkcs11-proxy. Written while learning
how the (not very documented) wire format worked, as well as to help
diagnose (and fix) a couple of bugs in the C implementation.

Usage :

 On console 1 :

   $ PKCS11_DAEMON_SOCKET="tcp://127.0.0.1:2345" pkcs11-daemon /usr/lib/libsofthsm.so

 On console 2 :

   $ ./p11proxy-mitm --debug

 On console 3 :

   $ PKCS11_PROXY_SOCKET="tcp://127.0.0.1:2344" pkcs11-tool --show-info --module libpkcs11-proxy.so

"""

import sys
import struct
import socket
import logging
import argparse
import SocketServer
import PyKCS11 # used to get all the PKCS#11 defines

# Defaults for the command line options (see parse_args()).
default_host = "localhost"
default_port_in, default_port_out = 2344, 2345
default_timeout = 3
default_debug = False

# Module-wide state. Nothing is assigned here; the names are only declared
# so that code further down can rebind them with `global`.
args = yhsm = logger = None

# Call table extracted from gck-rpc-private.h.
#
# Every entry of request_list is a tuple
#
#   (call id, call name, request format, response format)
#
# where the call id equals the position of the entry in the list, and the
# two format strings use the following alphabet:
#
# /*
#  *  a_ = prefix denotes array of _
#  *  A  = CK_ATTRIBUTE
#  *  f_ = prefix denotes buffer for _
#  *  M  = CK_MECHANISM
#  *  u  = CK_ULONG
#  *  s  = space padded string
#  *  v  = CK_VERSION
#  *  y  = CK_BYTE
#  *  z  = null terminated string
#  */
#
# The ids are derived from the order below (the trailing comments are only
# there for the reader), so entries must never be reordered.
request_list = list(
    (call_id,) + signature
    for (call_id, signature) in enumerate([
        ("ERROR", None, None),                                  # 0
        ("C_Initialize", "ay", ""),                             # 1
        ("C_Finalize", "", ""),                                 # 2
        ("C_GetInfo", "", "vsusv"),                             # 3
        ("C_GetSlotList", "yfu", "au"),                         # 4
        ("C_GetSlotInfo", "u", "ssuvv"),                        # 5
        ("C_GetTokenInfo", "u", "ssssuuuuuuuuuuuvvs"),          # 6
        ("C_GetMechanismList", "ufu", "au"),                    # 7
        ("C_GetMechanismInfo", "uu", "uuu"),                    # 8
        ("C_InitToken", "uays", ""),                            # 9
        ("C_WaitForSlotEvent", "u", "u"),                       # 10
        ("C_OpenSession", "uu", "u"),                           # 11
        ("C_CloseSession", "u", ""),                            # 12
        ("C_CloseAllSessions", "u", ""),                        # 13
        ("C_GetFunctionStatus", "u", ""),                       # 14
        ("C_CancelFunction", "u", ""),                          # 15
        ("C_GetSessionInfo", "u", "uuuu"),                      # 16
        ("C_InitPIN", "uay", ""),                               # 17
        ("C_SetPIN", "uayay", ""),                              # 18
        ("C_GetOperationState", "ufy", "ay"),                   # 19
        ("C_SetOperationState", "uayuu", ""),                   # 20
        ("C_Login", "uuay", ""),                                # 21
        ("C_Logout", "u", ""),                                  # 22
        ("C_CreateObject", "uaA", "u"),                         # 23
        ("C_CopyObject", "uuaA", "u"),                          # 24
        ("C_DestroyObject", "uu", ""),                          # 25
        ("C_GetObjectSize", "uu", "u"),                         # 26
        ("C_GetAttributeValue", "uufA", "aAu"),                 # 27
        ("C_SetAttributeValue", "uuaA", ""),                    # 28
        ("C_FindObjectsInit", "uaA", ""),                       # 29
        ("C_FindObjects", "ufu", "au"),                         # 30
        ("C_FindObjectsFinal", "u", ""),                        # 31
        ("C_EncryptInit", "uMu", ""),                           # 32
        ("C_Encrypt", "uayfy", "ay"),                           # 33
        ("C_EncryptUpdate", "uayfy", "ay"),                     # 34
        ("C_EncryptFinal", "ufy", "ay"),                        # 35
        ("C_DecryptInit", "uMu", ""),                           # 36
        ("C_Decrypt", "uayfy", "ay"),                           # 37
        ("C_DecryptUpdate", "uayfy", "ay"),                     # 38
        ("C_DecryptFinal", "ufy", "ay"),                        # 39
        ("C_DigestInit", "uM", ""),                             # 40
        ("C_Digest", "uayfy", "ay"),                            # 41
        ("C_DigestUpdate", "uay", ""),                          # 42
        ("C_DigestKey", "uu", ""),                              # 43
        ("C_DigestFinal", "ufy", "ay"),                         # 44
        ("C_SignInit", "uMu", ""),                              # 45
        ("C_Sign", "uayfy", "ay"),                              # 46
        ("C_SignUpdate", "uay", ""),                            # 47
        ("C_SignFinal", "ufy", "ay"),                           # 48
        ("C_SignRecoverInit", "uMu", ""),                       # 49
        ("C_SignRecover", "uayfy", "ay"),                       # 50
        ("C_VerifyInit", "uMu", ""),                            # 51
        ("C_Verify", "uayay", ""),                              # 52
        ("C_VerifyUpdate", "uay", ""),                          # 53
        ("C_VerifyFinal", "uay", ""),                           # 54
        ("C_VerifyRecoverInit", "uMu", ""),                     # 55
        ("C_VerifyRecover", "uayfy", "ay"),                     # 56
        ("C_DigestEncryptUpdate", "uayfy", "ay"),               # 57
        ("C_DecryptDigestUpdate", "uayfy", "ay"),               # 58
        ("C_SignEncryptUpdate", "uayfy", "ay"),                 # 59
        ("C_DecryptVerifyUpdate", "uayfy", "ay"),               # 60
        ("C_GenerateKey", "uMaA", "u"),                         # 61
        ("C_GenerateKeyPair", "uMaAaA", "uu"),                  # 62
        ("C_WrapKey", "uMuufy", "ay"),                          # 63
        ("C_UnwrapKey", "uMuayaA", "u"),                        # 64
        ("C_DeriveKey", "uMuaA", "u"),                          # 65
        ("C_SeedRandom", "uay", ""),                            # 66
        ("C_GenerateRandom", "ufy", "ay"),                      # 67
    ])
)

def parse_args():
    """
    Build the command line parser and return the parsed options.

    The options are described as (flags, keyword arguments) pairs and are
    registered in the order listed, which is also the order shown by --help.
    """
    option_table = [
        (('-H', '--host'),
         dict(dest='listen_host', default=default_host,
              help='Host address to listen on', metavar='HOST')),
        (('--host-out',),
         dict(dest='connect_to_host', default=default_host,
              help='Host address to connect to', metavar='HOST')),
        (('--port-in',),
         dict(dest='listen_port', type=int, default=default_port_in,
              help='Port to listen on', metavar='PORT')),
        (('--port-out',),
         dict(dest='connect_to_port', type=int, default=default_port_out,
              help='Port to connect to', metavar='PORT')),
        (('--timeout',),
         dict(dest='timeout', type=int, default=default_timeout,
              required=False, help='Request timeout in seconds',
              metavar='SECONDS')),
        (('--debug',),
         dict(dest='debug', action='store_true', default=default_debug,
              help='Enable debug operation')),
    ]

    cli = argparse.ArgumentParser(
        description="pkcs11-proxy man-in-the-middle",
        add_help=True,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for (flags, settings) in option_table:
        cli.add_argument(*flags, **settings)

    return cli.parse_args()

class ProxyHandler(SocketServer.StreamRequestHandler):
    """
    The RequestHandler class for our server.

    It is instantiated once per connection to the server. Every instance
    opens its own connection to the real pkcs11-daemon, relays the traffic
    unmodified in both directions and logs a decoded view of each
    request/answer using self.logger.
    """

    def __init__(self, *other_args, **kwargs):
        """
        Connect to the real server, then hand over to the base class.

        The base class constructor is what ends up calling handle(), so all
        the attributes handle() needs are assigned before it is invoked.
        """
        self.timeout = args.timeout
        self.logger = logger

        # Outgoing connection to the real server (pkcs11-daemon)
        self.p11proxy = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.p11proxy.connect((args.connect_to_host, args.connect_to_port))
        except Exception:
            # Don't leave a half-set-up socket behind if the daemon is unreachable
            self.p11proxy.close()
            raise

        SocketServer.BaseRequestHandler.__init__(self, *other_args, **kwargs)

    def finish(self):
        """
        Tear down the client side (base class) and close the outgoing connection.
        """
        try:
            SocketServer.StreamRequestHandler.finish(self)
        finally:
            self.p11proxy.close()

    def _recv_exact(self, sock, count):
        """
        Read exactly `count` bytes from `sock`.

        A single recv() may return fewer bytes than asked for, so keep reading
        until everything has arrived. If the peer closes the connection first,
        whatever was received so far is returned - callers detect that by
        comparing the length of the result with `count`.
        """
        chunks = []
        missing = count
        while missing > 0:
            chunk = sock.recv(missing)
            if not chunk:
                break
            chunks.append(chunk)
            missing -= len(chunk)
        return b''.join(chunks)

    def handle(self):
        """
        Handle an incoming connection. Read requests/responses and decode them after passing
        them on to the real pkcs11-proxy client/server.

        Any exception is logged and ends the handling of this connection.
        """
        # self.request is the TCP socket connected to the client
        try:
            # 64 bits app id. Randomized on the client, and not really used in the server.
            self.appid = self._recv_exact(self.request, 8)
            self.p11proxy.sendall(self.appid)

            while True:
                # Read request. First 32 bits are length.
                length = self._recv_exact(self.request, 4)
                if len(length) < 4:
                    # Client closed the connection
                    break
                # Pass length to real server.
                self.p11proxy.sendall(length)
                self.logger.debug("\n\nRequest length is {}".format(length.encode('hex')))
                self._handle_request(struct.unpack('>L', length)[0])

                # Read response. Again, first 32 bits are length.
                length = self._recv_exact(self.p11proxy, 4)
                if len(length) < 4:
                    self.logger.warning("Server closed the connection without answering")
                    break
                # Pass length to client.
                self.request.sendall(length)
                self.logger.debug("\n\nAnswer length is {}".format(length.encode('hex')))
                self._handle_answer(struct.unpack('>L', length)[0])
        except Exception:
            self.logger.exception("Got exception handling request from {}".format(self.client_address[0]))
            return None

    def _handle_request(self, length):
        """
        Given expected length, read the rest of a request, pass it on and decode whatever we can.

        Returns whatever _parse_message() returns when the call has a known,
        non-empty request format, otherwise None. Raises IOError if the client
        disconnects before the whole request has been received.
        """
        self.logger.debug("Request({})".format(length))
        data = self._recv_exact(self.request, length)
        self.logger.debug(" R: {}".format(data.encode('hex')))
        self.p11proxy.sendall(data)
        if len(data) < length:
            raise IOError("Client disconnected after {} of {} request bytes".format(len(data), length))
        if len(data) < 4:
            # Too short to even hold a call id, nothing to decode
            return None
        # parse request
        (call_id, rest) = self._get_uint32(data)
        if call_id >= len(request_list):
            self.logger.warning(" R: unknown call_id {}, not decoding".format(call_id))
            return None
        self.logger.debug(" R: call_id {}".format(call_id))
        rinfo = request_list[call_id]
        self.logger.debug(" R: name {}".format(rinfo[1]))
        fmt = rinfo[2]
        if fmt:
            self.logger.debug(" R: fmt {} ({})".format(fmt, fmt.encode('hex')))
            return self._parse_message('R', fmt, rest)

    def _handle_answer(self, length):
        """
        Given expected length, read the rest of a response, pass it on and decode whatever we can.

        Raises IOError if the server disconnects before the whole answer has
        been received.
        """
        self.logger.debug("Answer({})".format(length))
        data = self._recv_exact(self.p11proxy, length)
        self.logger.debug(" A: {}".format(data.encode('hex')))
        self.request.sendall(data)
        if len(data) < length:
            raise IOError("Server disconnected after {} of {} answer bytes".format(len(data), length))
        if len(data) < 4:
            # Too short to even hold a call id, nothing to decode
            return None
        # parse answer
        (call_id, rest) = self._get_uint32(data)
        if call_id >= len(request_list):
            self.logger.warning(" A: unknown call_id {}, not decoding".format(call_id))
            return None
        self.logger.debug(" A: call_id {}".format(call_id))
        rinfo = request_list[call_id]
        self.logger.debug(" A: name {}".format(rinfo[1]))
        if not call_id:
            # This is ERROR, the return code is taken from the last four bytes
            (ckr, _) = self._get_uint32(data[-4:])
            # .get() so that an unknown return code can't abort the connection
            self.logger.debug(" A: {}".format(PyKCS11.CKR.get(ckr)))
        fmt = rinfo[3]
        if fmt:
            self.logger.debug(" A: fmt {} ({})".format(fmt, fmt.encode('hex')))
            self._parse_message('A', fmt, rest)

    def _parse_message(self, msgtype, fmt, data):
        """
        Parse the data in the request/response, which share format.

        `msgtype` is only used as a prefix in log messages ('R' or 'A'), `fmt`
        is the format we expect and `data` is the message without the call id.

        Returns the list of decoded values, None if the format embedded in the
        message differs from `fmt`, or an empty list if decoding failed.
        """
        # Check that the data is of the format we expect it to be. The format of every request/answer
        # is redundantly included in the data sent over the network, although we already know what
        # it should be. I guess this is a way to reduce the likelihood of undetected stream de-sync,
        # because it would be wasteful if it was just used to make developers forget to update the
        # handshake string when changing the format for a request/response.
        (parsed_fmt, rest) = self._get_byte_array(data)
        if parsed_fmt != fmt:
            self.logger.error(" {}: Format mismatch, mine is {} and received is {}".format(msgtype, fmt, parsed_fmt))
            return None
        self.logger.debug(" {}: format match".format(msgtype))
        try:
            return self._value_dump(msgtype, parsed_fmt, rest)
        except Exception:
            # Decoding is best-effort, the message itself has already been relayed
            self.logger.exception("Got exception trying to dump all data")
            return []

    def _value_dump(self, msgtype, fmt, data, res = None):
        """
        Decode the data in a request/answer according to fmt.

        Will output the decoded data using self.logger.debug().

        The first element of `fmt` is decoded from the start of `data`, then
        the function recurses on the remainders. `res` is the list the decoded
        values are collected in (a new one is created when it is omitted), and
        that list is what gets returned. CK_VERSION and CK_BYTE values are only
        logged, not collected.
        """
        if res is None:
            # A fresh list per top-level call, never a shared default
            res = []
        if not fmt:
            if data:
                self.logger.warning("{} bytes left after processing : {}".format(len(data), data.encode('hex')))
            return res
        #self.logger.debug("PARSING {} FROM {}".format(fmt[0], data.encode('hex')))
        if fmt[0] == 'u':
            (this, rest) = self._get_uint64(data)
            res.append(this)
            self.logger.debug(" {}: uint64 {}".format(msgtype, hex(this)))
        elif fmt[0] == 'a':
            #self.logger.debug("PARSING {} FROM {}".format(fmt[0], data.encode('hex')))
            if fmt[1] == 'A':
                # No validity byte is read for attribute arrays
                valid = 1
                rest = data
            else:
                (valid, rest) = self._get_uint8(data)
                self.logger.debug(" {}: Byte-Array valid : {}".format(msgtype, valid))
            (array_num, rest) = self._get_uint32(rest)
            self.logger.debug(" {}: Byte-Array length : {}/{}".format(msgtype, array_num, hex(array_num)))
            if array_num > 256:
                raise Exception("Unreasonable array length : {}".format(hex(array_num)))
            if not valid:
                #  /* If not valid, then just the length is encoded, this can signify CKR_BUFFER_TOO_SMALL */
                self.logger.debug(" {}: Array length {}/{}, not valid - maybe CKR_BUFFER_TOO_SMALL".format(msgtype, array_num, array_num))
                # array is now fully handled, remove next element of fmt
                fmt = fmt[1:]
                if rest:
                    self.logger.debug(" {}: CKR_BUFFER_TOO_SMALL leaving data {} ({} bytes) for fmt {}".format(msgtype, rest.encode('hex'), len(rest), fmt[1:]))
            else:
                if fmt[1] == 'y':
                    # simple byte array (i.e. string to us)
                    rest = struct.pack('>L', array_num) + rest  # restore length before calling _get_byte_array
                    (this, rest) = self._get_byte_array(rest)
                    self.logger.debug(" {}: Byte-Array({}) {}".format(msgtype, len(this), repr(this)))
                    res.append(this)
                    fmt = fmt[1:]
                else:
                    if array_num > 0:
                        # modify fmt to match the array length: the placeholder stands in for
                        # the 'a' that is chomped at the end of this function, leaving
                        # array_num copies of the element type in front of the remaining fmt
                        placeholder = 'X'
                        array_elements = fmt[1] * (array_num - 1)
                        fmt_rest = fmt[1:]
                        fmt = placeholder + array_elements + fmt_rest
                    else:
                        # chomp the next fmt too
                        fmt = fmt[1:]
                    self.logger.debug(" {}: Array length {}, new fmt {}".format(msgtype, array_num, fmt))
        elif fmt[0] == 's':
            (this, rest) = self._get_byte_array(data)
            res.append(this)
            self.logger.debug(" {}: string({}/{}) {}".format(msgtype, len(this), hex(len(this)), repr(this)))
        elif fmt[0] == 'v':
            # version, struct ck_version from pkcs11.h
            (major, rest) = self._get_uint8(data)
            (minor, rest) = self._get_uint8(rest)
            self.logger.debug(" {}: CK_Version {}.{}".format(msgtype, major, minor))
        elif fmt[0] == 'y':
            (this, rest) = self._get_uint8(data)
            self.logger.debug(" {}: CK_Byte {}".format(msgtype, hex(this)))
        elif fmt[0] == 'z':
            # NULL-terminated string, len == -1 for empty string
            (length, rest) = self._get_uint32(data)
            if length == 0xffffffff:
                res.append('')
            else:
                (this, rest) = self._get_byte_array(data)
                res.append(this)
                self.logger.debug(" {}: NULL-string({}/{}) {}".format(msgtype, len(this), hex(len(this)), repr(this)))
        elif fmt[0] == 'M':
            # Mechanism
            #self.logger.debug("PARSING {} FROM {}".format(fmt[0], data.encode('hex')))
            (mech_type, rest) = self._get_uint32(data)
            self.logger.debug(" {}: Mechanism {}/{}".format(msgtype, repr(mech_type), PyKCS11.CKM.get(mech_type)))
            (this, rest) = self._get_byte_array(rest)
            self.logger.debug(" {}: Mechanism parameter {}".format(msgtype, repr(this)))
            res.append((mech_type, this))
        elif fmt[0] == 'A':
            # Attribute
            (this, rest) = self._get_attribute(data)
            res.append(this)
            #self.logger.debug(" {}: Attribute {}".format(msgtype, this))
            vl = "None"
            if this['uValueLen'] is not None:
                vl = hex(this['uValueLen'])
            self.logger.debug(" {}: Attribute type {}/{}, valid {}, len {}, data {}".format(\
                    msgtype, hex(this['type']), PyKCS11.CKA.get(this['type']), this['valid'], vl, this['data'].encode('hex')))
        elif fmt[0] == 'f':
            # array buffer
            if fmt[1] == 'y':
                (flags, rest) = self._get_uint8(data)
            else:
                rest = data
            (num_items, rest) = self._get_uint32(rest)
            if fmt[1] == 'A':
                self.logger.debug(" {}: Attribute buffer, {} elements".format(msgtype, hex(num_items)))
                while num_items:
                    (attr_type, rest) = self._get_uint32(rest)
                    (buffer_len, rest) = self._get_uint32(rest)
                    self.logger.debug(" {}:   type {}/{} len {}".format(msgtype, hex(attr_type), PyKCS11.CKA.get(attr_type), hex(buffer_len)))
                    res.append(('attribute buffer', attr_type, buffer_len))
                    num_items -= 1
            elif fmt[1] == 'y':
                self.logger.debug(" {}: Byte buffer, length {} (flags {})".format(msgtype, hex(num_items), flags))
                # just a number of bytes to alloc
            elif fmt[1] == 'u':
                # ulong buffer, just an uint32 (?) - it is the value already read into num_items
                self.logger.debug("PARSING {} FROM {}".format(fmt[0], data.encode('hex')))
                self.logger.debug(" {}: ulong buffer {}".format(msgtype, hex(num_items)))
            else:
                raise Exception("Unknown array buffer fmt '{}'".format(fmt[1]))
            # need to munch an extra fmt
            fmt = fmt[1:]
        else:
            self.logger.warning(" {}: STOPPING at fmt {}, data {}".format(msgtype, fmt[0], data.encode('hex')))
            return res
        return self._value_dump(msgtype, fmt[1:], rest, res)

    def _get_uint64(self, data):
        """
        Decode a big endian 64 bit unsigned integer. Returns (value, remaining data).
        """
        (high, rest) = self._get_uint32(data)
        (low, rest) = self._get_uint32(rest)
        return (high << 32 | low, rest)

    def _get_uint32(self, data):
        """
        Decode a big endian 32 bit unsigned integer. Returns (value, remaining data).

        Raises struct.error if `data` holds less than four bytes.
        """
        res = struct.unpack('>L', data[:4])[0]
        return (res, data[4:])

    def _get_uint8(self, data):
        """
        Decode a single byte. Returns (value, remaining data).
        """
        res = ord(data[0])
        return (res, data[1:])

    def _get_byte_array(self, data):
        """
        Decode a byte array prefixed with its 32 bit length.

        A length of 0xffffffff is decoded as an empty string. Returns
        (bytes, remaining data) and raises Exception if `data` is shorter
        than the length claims.
        """
        (length, rest) = self._get_uint32(data)
        if length == 0xffffffff:
            return('', rest)
        if length > len(rest):
            raise Exception('Parse error, not enough bytes left for byte-array')
        res = rest[:length]
        return (res, rest[length:])

    def _get_attribute(self, data):
        """
        Decode an attribute: 32 bit type, validity byte and, only when valid,
        a 32 bit value length followed by a byte array holding the value.

        Returns (attribute, remaining data), where attribute is a dict with
        the keys 'type', 'valid', 'uValueLen' (None when not valid) and 'data'
        ('' when not valid).
        """
        (attr_type, rest) = self._get_uint32(data)
        (attr_valid, rest) = self._get_uint8(rest)
        attr_vallen = None
        value = ''
        if attr_valid:
            (attr_vallen, rest) = self._get_uint32(rest)
            # length doubly included, in byte array again?
            (value, rest) = self._get_byte_array(rest)
        attr = {'type': attr_type,
                'valid': attr_valid,
                'uValueLen': attr_vallen,
                'data': value,
                }
        return (attr, rest)

def main():
    """
    Parse the command line, set up logging and run the proxy server.

    The parsed arguments and the logger are stored in the module level
    `args` and `logger`, where ProxyHandler picks them up.

    Returns True once the server loop has ended. An exception such as
    KeyboardInterrupt (Ctrl-C) propagates to the caller, but the listening
    socket is closed first.
    """
    global args
    global logger
    args = parse_args()

    level = logging.INFO
    if args.debug:
        level = logging.DEBUG
    logging.basicConfig(level=level)
    logger = logging.getLogger('p11proxy-mitm')

    # Create the server
    server = SocketServer.TCPServer((args.listen_host, args.listen_port),
                                    ProxyHandler,
                                    bind_and_activate=False,
                                    )
    server.allow_reuse_address = True
    server.timeout = args.timeout
    try:
        server.server_bind()     # Manually bind, to support allow_reuse_address
        server.server_activate() # (see above comment)

        # Activate the server; this will keep running until you
        # interrupt the program with Ctrl-C
        server.serve_forever()
    finally:
        # Release the listening socket however the loop was left
        server.server_close()
    return True

# Script entry point: exit status 0 when main() returns something true,
# 1 otherwise. Ctrl-C is swallowed so that it doesn't produce a traceback.
if __name__ == '__main__':
    try:
        sys.exit(0 if main() else 1)
    except KeyboardInterrupt:
        pass
