#!/usr/bin/env python

import os
import optparse
import sys
import types
import socket
from kznf.kznfnetlink import *
from kznf.nfnetlink import *

# Message type number sent by upload_zones() right after the zone flush
# request; it is defined here as a literal value.
KZNL_MSG_COPY_ZONE_DAC = 16


class AttributeRequiredError(Exception):
    """Raised when a mandatory attribute is missing from a KZorp message.

    This used to be a plain string.  String exceptions cannot be raised on
    Python 2.6 and later, so every ``raise AttributeRequiredError, "NAME"``
    ended in a TypeError instead of reporting the missing attribute.  A real
    exception class works with both the two-argument raise statement used in
    this file and the call form ``AttributeRequiredError("NAME")``.
    """

    def __str__(self):
        # Keep the historic message text and append the attribute name(s)
        # that were passed to the exception.
        return "required attribute missing: %s" % \
               (", ".join([str(a) for a in self.args]),)

def inet_ntoa(a):
    """Format the integer *a* as a dotted-quad string.

    Bits 24-31 of *a* become the first octet and bits 0-7 the last one;
    anything above bit 31 is ignored.
    """
    octets = [(a >> shift) & 0xff for shift in (24, 16, 8, 0)]
    return ".".join([str(octet) for octet in octets])

def inet_aton(a):
    """Convert the dotted-quad IPv4 string *a* into an integer.

    The first octet ends up in the most significant byte of the result, so
    "127.0.0.1" becomes 0x7f000001.

    Raises ValueError when *a* does not consist of exactly four decimal
    octets in the range 0..255.  Such input used to be converted silently
    into a meaningless number.
    """
    octets = a.split(".")
    if len(octets) != 4:
        raise ValueError("invalid IPv4 address: %r" % (a,))

    r = 0
    for n in octets:
        value = int(n)  # raises ValueError for non-numeric octets
        if not 0 <= value <= 255:
            raise ValueError("invalid IPv4 address: %r" % (a,))
        r = (r << 8) + value
    return r

def size_to_mask(n):
    """Return the 32-bit netmask that has its *n* topmost bits set.

    size_to_mask(24) is 0xffffff00, size_to_mask(0) is 0 and
    size_to_mask(32) is 0xffffffff.

    Raises ValueError when *n* is outside 0..32.  A negative length used to
    produce a negative, meaningless mask.
    """
    if not 0 <= n <= 32:
        raise ValueError("invalid prefix length: %r" % (n,))
    # (1 << (32 - n)) - 1 has exactly the host bits set; removing them from
    # the all-ones word leaves the network bits.
    return 0xffffffff - ((1 << (32 - n)) - 1)

def mask_to_description(mask, definition):
    """Describe the bits set in *mask* as a comma separated string.

    *definition* maps bit values to their names.  The names of all keys that
    have a bit in common with *mask* are joined with "," in ascending key
    order, so the result no longer depends on dictionary iteration order.
    An empty string is returned when no key matches.
    """
    names = [str(definition[bit]) for bit in sorted(definition) if mask & bit]
    return ",".join(names)


class DumpBase:
    """Common driver for the dump commands.

    A subclass supplies the message type to request and a function that
    builds the request payload, and overrides _msg_handler() to print the
    messages received in response.
    """

    def __init__(self, quiet, type, create_func):
        # 'type' shadows the builtin, but it is part of the constructor
        # signature and is therefore left alone.
        self.has_data = False
        self.create_func = create_func
        self.type = type
        self.quiet = quiet

    def dump(self):
        """Send the dump request and hand every reply to a handler.

        Returns 1 when talk() reports a non-zero result (the error is written
        to stderr).  In quiet mode the return value is 0 when at least one
        reply message was seen and 1 otherwise; in normal mode the (zero)
        talk() result is returned.
        """
        # initialize nfnetlink
        handle = Handle()
        handle.register_subsystem(Subsystem(NFNL_SUBSYS_KZORP))

        # create dump message
        request = handle.create_message(NFNL_SUBSYS_KZORP, self.type,
                                        NLM_F_REQUEST | NLM_F_DUMP)
        request.set_nfmessage(self.create_func(None))

        # quiet mode only records that data arrived, nothing is printed
        if self.quiet:
            callback = self._msg_handler_quiet
        else:
            callback = self._msg_handler
        status = handle.talk(request, (0, 0), callback)

        if status != 0:
            sys.stderr.write("Dump failed: result='%d' error='%s'\n" % (status, os.strerror(-status)))
            return 1

        if not self.quiet:
            return status

        if self.has_data:
            return 0
        return 1

    def _msg_handler(self, msg):
        """Handle one reply message; the base implementation ignores it."""
        return None

    def _msg_handler_quiet(self, msg):
        """Note that a reply message arrived, without printing anything."""
        self.has_data = True

class DumpZones(DumpBase):
    def __init__(self, quiet):
        DumpBase.__init__(self, quiet, KZNL_MSG_GET_ZONE, create_get_zone_msg)

    def _print_add_zone_svc(self, msg, attrs):
        if attrs.has_key(KZA_ZONE_UNAME):
            name = parse_name_attr(attrs[KZA_ZONE_UNAME])
        else:
            raise AttributeRequiredError, "KZA_ZONE_UNAME"

        if attrs.has_key(KZA_SVC_NAME):
            svc_name = parse_name_attr(attrs[KZA_SVC_NAME])
        else:
            raise AttributeRequiredError, "KZA_SVC_NAME"

        print "%s: '%s'" % (msg, svc_name)

    def _print_add_zone(self, attrs):

        zone_flags = {1: "umbrella"}
        
        if attrs.has_key(KZA_ZONE_PARAMS):
            flags = parse_int32_attr(attrs[KZA_ZONE_PARAMS])
        else:
            raise AttributeRequiredError, "KZA_ZONE_PARAMS"

        if attrs.has_key(KZA_ZONE_NAME):
            name = parse_name_attr(attrs[KZA_ZONE_NAME])
        else:
            name = None

        if attrs.has_key(KZA_ZONE_UNAME):
            unique_name = parse_name_attr(attrs[KZA_ZONE_UNAME])
        else:
            raise AttributeRequiredError, "KZA_ZONE_UNAME"

        if attrs.has_key(KZA_ZONE_PNAME):
            admin_parent = parse_name_attr(attrs[KZA_ZONE_PNAME])
        else:
            admin_parent = None

        if attrs.has_key(KZA_ZONE_RANGE):
            (addr, mask) = parse_inet_range_attr(attrs[KZA_ZONE_RANGE])
            range_str = ", range '%s/%s'" % (inet_ntoa(socket.htonl(addr)), inet_ntoa(socket.htonl(mask)))
        else:
            range_str = ""

        flags_str = mask_to_description(flags, zone_flags)

        print "Zone unique_name='%s', visible_name='%s', admin_parent='%s',\n        flags '%s'%s" % \
              (unique_name, name, admin_parent, flags_str, range_str)

    def _msg_handler(self, msg):
        attrs = msg.get_nfmessage().get_attributes()

        if msg.type & 0xff == KZNL_MSG_ADD_ZONE:
            self._print_add_zone(attrs)
        if msg.type & 0xff == KZNL_MSG_ADD_ZONE_SVC_IN:
            self._print_add_zone_svc("        Inbound service: ", attrs)
        if msg.type & 0xff == KZNL_MSG_ADD_ZONE_SVC_OUT:
            self._print_add_zone_svc("        Outbound service: ", attrs)

class DumpServices(DumpBase):
    def __init__(self, quiet):
        DumpBase.__init__(self, quiet, KZNL_MSG_GET_SERVICE, create_get_service_msg)

    def _nat_range_str(self, flags, ip1, ip2, p1, p2):
        if ip1 == ip2:
            return "%s" % (inet_ntoa(ip1),)
        else:
            return "(%s - %s)" % (inet_ntoa(ip1), inet_ntoa(ip2))

    def _print_add_svc_nat(self, msg, attrs):
        if attrs.has_key(KZA_SVC_NAME):
            name = parse_name_attr(attrs[KZA_SVC_NAME])
        else:
            raise AttributeRequiredError, "KZA_SVC_NAME"

        if attrs.has_key(KZA_SVC_NAT_SRC):
            sflags, sip1, sip2, sp1, sp2 = parse_nat_range_attr(attrs[KZA_SVC_NAT_SRC])
        else:
            raise AttributeRequiredError, "KZA_SVC_NAT_SRC"

        if attrs.has_key(KZA_SVC_NAT_DST):
            dflags, dip1, dip2, dp1, dp2 = parse_nat_range_attr(attrs[KZA_SVC_NAT_DST])
        else:
            dflags = None

        if attrs.has_key(KZA_SVC_NAT_MAP):
            mflags, mip1, mip2, mp1, mp2 = parse_nat_range_attr(attrs[KZA_SVC_NAT_MAP])
        else:
            raise AttributeRequiredError, "KZA_SVC_NAT_MAP"

        if dflags:
            print "%s src %s dst %s mapped to %s" % \
                  (msg, self._nat_range_str(sflags, sip1, sip2, sp1, sp2), \
                   self._nat_range_str(dflags, dip1, dip2, dp1, dp2), \
                   self._nat_range_str(mflags, mip1, mip2, mp1, mp2))
        else:
            print "%s src %s mapped to %s" % \
                  (msg, self._nat_range_str(sflags, sip1, sip2, sp1, sp2), \
                   self._nat_range_str(mflags, mip1, mip2, mp1, mp2))

    def _print_add_svc(self, attrs):

        svc_type = ("INVALID", "Service", "PFService")
        svc_flags = {1: "transparent", 2: "forge_addr"}
        
        if attrs.has_key(KZA_SVC_NAME):
            name = parse_name_attr(attrs[KZA_SVC_NAME])
        else:
            raise AttributeRequiredError, "KZA_SVC_NAME"

        if attrs.has_key(KZA_SVC_PARAMS):
            flags, type = parse_service_params_attr(attrs[KZA_SVC_PARAMS])
        else:
            raise AttributeRequiredError, "KZA_SVC_PARAMS"

        if attrs.has_key(KZA_SVC_SESSION_CNT):
            cnt = parse_int32_attr(attrs[KZA_SVC_SESSION_CNT])
        else:
            cnt = None

        if attrs.has_key(KZA_SVC_ROUTER_DST):
            addr, port, proto = parse_address_attr(attrs[KZA_SVC_ROUTER_DST])
        else:
            addr, port, proto = (None, None, None)

        flags_str = mask_to_description(flags, svc_flags)

        print "Service name='%s', flags='%s', type='%s', session_cnt='%d'" % (name, flags_str, svc_type[type], cnt)

        if addr and port:
            print "        router_dst='%s:%d'" % \
                  (inet_ntoa(addr), port)

    def _msg_handler(self, msg):
        attrs = msg.get_nfmessage().get_attributes()

        if msg.type & 0xff == KZNL_MSG_ADD_SERVICE:
            self._print_add_svc(attrs)
        if msg.type & 0xff == KZNL_MSG_ADD_SERVICE_NAT_SRC:
            self._print_add_svc_nat("        SNAT: ", attrs)
        if msg.type & 0xff == KZNL_MSG_ADD_SERVICE_NAT_DST:
            self._print_add_svc_nat("        DNAT: ", attrs)

class DumpDispatchers(DumpBase):
    def __init__(self, quiet):
      DumpBase.__init__(self, quiet, KZNL_MSG_GET_DISPATCHER, create_get_dispatcher_msg)
      self.rules_value = {}
      self._max = 0
      self._index = 0
      self.dpt_protocols = {6: "TCP", 17: "UDP"}

    def _print_ports(self, ports):
        p = ""
        for s, e in ports:
            if s == e:
                p = "".join((p, "%d," % (s,)))
            else:
                p = "".join((p, "%d:%d," % (s, e)))
        return p.rstrip(",")

    def _print_add_dpt(self, attrs):

        dpt_flags = {1: "transparent", 2: "follow_parent"}

        if attrs.has_key(KZA_DPT_NAME):
            name = parse_name_attr(attrs[KZA_DPT_NAME])
        else:
            raise AttributeRequiredError, "KZA_DPT_NAME"

        if attrs.has_key(KZA_DPT_PARAMS):
            flags, proxy_port, dpt_type = parse_dispatcher_params_attr(attrs[KZA_DPT_PARAMS])
        else:
            raise AttributeRequiredError, "KZA_DPT_PARAMS"

        if dpt_type == KZ_DPT_TYPE_INET:
            if attrs.has_key(KZA_DPT_BIND_ADDR):
                proto, addr, ports = parse_bind_addr_attr(attrs[KZA_DPT_BIND_ADDR])
                proto_str = self.dpt_protocols[proto]
                addr_str = "        proto='%s', addr='%s', proxy_port='%d', num_ranges='%d', ports='%s'" % \
                           (proto_str, inet_ntoa(addr), proxy_port, len(ports), self._print_ports(ports))
            else:
                raise AttributeRequiredError, "KZA_DPT_BIND_ADDR"
        elif dpt_type == KZ_DPT_TYPE_IFACE:
            if attrs.has_key(KZA_DPT_BIND_IFACE):
                proto, iface, ports, pref_addr = parse_bind_iface_attr(attrs[KZA_DPT_BIND_IFACE])
                proto_str = self.dpt_protocols[proto]
                addr_str = "        proto='%s', iface='%s', pref_addr='%s', proxy_port='%d', num_ranges='%d, ports='%s'" % \
                           (proto_str, iface, inet_ntoa(pref_addr), proxy_port, len(ports), self._print_ports(ports))
            else:
                raise AttributeRequiredError, "KZA_DPT_BIND_IFACE"
        elif dpt_type == KZ_DPT_TYPE_IFGROUP:
            if attrs.has_key(KZA_DPT_BIND_IFGROUP):
                proto, group, mask, ports, pref_addr = parse_bind_ifgroup_attr(attrs[KZA_DPT_BIND_IFGROUP])
                proto_str = self.dpt_protocols[proto]
                addr_str = "        proto='%s', ifgroup='%s', pref_addr='%s', proxy_port='%d', num_ranges='%d', ports='%s'" % \
                           (proto_str, group, inet_ntoa(pref_addr), proxy_port, len(ports), self._print_ports(ports))
            else:
                raise AttributeRequiredError, "KZA_DP_BIND_IFGROUP"
        elif dpt_type == KZ_DPT_TYPE_N_DIMENSION:
            num_rules = parse_n_dimension_attr(attrs[KZA_DISPATCHER_N_DIMENSION_PARAMS])
            addr_str = "        proxy_port='%d', num_rules='%d'" % (proxy_port, num_rules)

        flags_str = mask_to_description(flags, dpt_flags)

        print "Dispatcher name='%s' flags='%s'\n%s" % (name, flags_str, addr_str)

    def _print_add_dpt_css(self, attrs):
        if attrs.has_key(KZA_DPT_NAME):
            name = parse_name_attr(attrs[KZA_DPT_NAME])
        else:
            raise AttributeRequiredError, "KZA_DPT_NAME"

        if attrs.has_key(KZA_DPT_CSS_CZONE):
            czone = parse_name_attr(attrs[KZA_DPT_CSS_CZONE])
        else:
            czone = "*"

        if attrs.has_key(KZA_DPT_CSS_SZONE):
            szone = parse_name_attr(attrs[KZA_DPT_CSS_SZONE])
        else:
            szone = "*"

        if attrs.has_key(KZA_SVC_NAME):
            sname = parse_name_attr(attrs[KZA_SVC_NAME])
        else:
            raise AttributeRequiredError, "KZA_SVC_NAME"

        print "        ('%s', '%s') -> '%s'" % (czone, szone, sname)

    def _print_add_rule(self, attrs):
      rule_id, service, rules = parse_rule_attrs(attrs)
      if (rules.values() == []):
        self._max = 0
      else:
        (self._max,) = max(rules.values())

      print "        rule_id='%d', service='%s'" % (rule_id, service)

    def _print_add_rule_entry(self, attrs):
      dimensions = [ (KZA_N_DIMENSION_AUTH     , 'auth'),   (KZA_N_DIMENSION_IFACE    , 'iface'),    (KZA_N_DIMENSION_IFGROUP  , 'ifgroup'), \
                     (KZA_N_DIMENSION_PROTO    , 'proto'),  (KZA_N_DIMENSION_SRC_PORT , 'src_port'), (KZA_N_DIMENSION_DST_PORT , 'dst_port'), \
                     (KZA_N_DIMENSION_SRC_IP   , 'src_ip'), (KZA_N_DIMENSION_SRC_ZONE , 'src_zone'), (KZA_N_DIMENSION_DST_IP   , 'dst_ip'), \
                     (KZA_N_DIMENSION_DST_ZONE , 'dst_zone') ]

      # NOTE: we detect that all entries were received by counting the
      # ADD_RULE_ENTRY messages and comparing that to the max
      # dimension array length. This is OK with the current kernel
      # implementation but may break if we change the kernel.

      rule_id, rule_entries = parse_rule_entry_attrs(attrs)
      for k, v in rule_entries.items():
        if not k in self.rules_value:
          self.rules_value[k] = []
        if k == KZA_N_DIMENSION_SRC_IP or k == KZA_N_DIMENSION_DST_IP:
          self.rules_value[k].append((inet_ntoa(v[0]), inet_ntoa(v[1])))
        elif k == KZA_N_DIMENSION_PROTO:
          self.rules_value[k].append(self.dpt_protocols[v[0]])
        elif k == KZA_N_DIMENSION_SRC_PORT or k == KZA_N_DIMENSION_DST_PORT:
          self.rules_value[k].append((v[0], v[1]))
        else:
          self.rules_value[k].append(v[0])
      self._index += 1
      if self._index == self._max:
        for k in dimensions:
          if k[0] in self.rules_value:
            print "           %s=%s " % (k[1], self.rules_value[k[0]])
        self.rules_value = {}
        self._index = 0
        print ""

    def _msg_handler(self, msg):
        attrs = msg.get_nfmessage().get_attributes()

        if msg.type & 0xff == KZNL_MSG_ADD_DISPATCHER:
            self._print_add_dpt(attrs)
        if msg.type & 0xff == KZNL_MSG_ADD_DISPATCHER_CSS:
            self._print_add_dpt_css(attrs)
        if msg.type & 0xff == KZNL_MSG_ADD_RULE:
            self._print_add_rule(attrs)
        if msg.type & 0xff == KZNL_MSG_ADD_RULE_ENTRY:
            self._print_add_rule_entry(attrs)

def upload_zones(fname):
    """Replace the zone structure with the zones listed in *fname*.

    Each non-empty line that does not start with "#" has the form
    ``zone;parent;umbrella;range[,range...]``.  The zone and parent may be
    enclosed in double quotes, umbrella is 1 for an umbrella zone, and each
    range is an address or address/prefix-length.

    The upload runs as one transaction: start, flush, one or more add
    messages per line, commit.  Returns 0 on success and 1 when the file
    cannot be opened or a line cannot be processed; in the latter case the
    transaction is not committed.
    """

    def exchange_message(h, msg, payload):
        """Send one message of type *msg* and raise if the result is non-zero."""
        m = h.create_message(NFNL_SUBSYS_KZORP, msg, NLM_F_REQUEST | NLM_F_ACK)
        m.set_nfmessage(payload)
        result = h.talk(m, (0, 0), None)
        if result != 0:
            raise NfnetlinkException, "Error while talking to KZorp"

    def parse_range(r):
        """Return the (address, mask) pair for "a.b.c.d" or "a.b.c.d/len"."""
        if "/" not in r:
            # simple IP address
            return (inet_aton(r), 0xffffffff)
        # IP range
        (a, m) = r.split("/")
        return (inet_aton(a), size_to_mask(int(m)))

    def process_line(h, l):
        """Upload the zone described by the stripped line *l*."""
        # skip comments
        if l.startswith("#"):
            return

        zone, parent, umbrella, r = l.split(";")
        zone = zone.strip('"')
        parent = parent.strip('"')
        ranges = r.split(",")

        if int(umbrella) == 1:
            flags = KZF_ZONE_UMBRELLA
        else:
            flags = 0

        if parent == "":
            parent = None

        if len(ranges) <= 1:
            # simple case: no ranges or exactly one range
            if ranges == [""]:
                addr = None
                mask = None
            else:
                addr, mask = parse_range(ranges[0])

            exchange_message(h, KZNL_MSG_ADD_ZONE,
                create_add_zone_msg(zone, flags, addr, mask, zone, parent))
        else:
            # we send the "parent" first
            exchange_message(h, KZNL_MSG_ADD_ZONE,
                create_add_zone_msg(zone, flags, None, None, zone, parent))
            # then one child zone per range, named "<zone>-#<n>" from 1
            for i, r in enumerate(ranges):
                uname = zone + "-#%d" % (i + 1,)
                addr, mask = parse_range(r)
                exchange_message(h, KZNL_MSG_ADD_ZONE,
                    create_add_zone_msg(zone, flags, addr, mask, uname, zone))

    # Open the file before touching KZorp, so that an unreadable file does
    # not leave a started and flushed transaction behind.
    try:
        f = open(fname)
    except IOError, e:
        sys.stderr.write("Error opening zone file '%s': %s\n" % (fname, e))
        return 1

    try:
        # initialize nfnetlink
        h = Handle()
        h.register_subsystem(Subsystem(NFNL_SUBSYS_KZORP))

        # start zone transaction
        exchange_message(h, KZNL_MSG_START, create_start_msg(KZ_TR_TYPE_ZONE, KZ_INSTANCE_GLOBAL))
        # flush zones
        exchange_message(h, KZNL_MSG_FLUSH_ZONE, create_flush_msg())
        exchange_message(h, KZNL_MSG_COPY_ZONE_DAC, create_flush_msg())

        # process each zone
        for l in f:
            l = l.strip()
            # blank lines used to abort the upload with an unpack error
            if not l:
                continue

            try:
                process_line(h, l)
            except Exception, e:
                sys.stderr.write("Error while processing the following line: %s\n%s\n" % (e, l))
                return 1

        # commit transaction
        exchange_message(h, KZNL_MSG_COMMIT, create_commit_msg())
    finally:
        # the file used to be left open
        f.close()

    return 0

def evaluate(parser, args, quiet):
    """Ask KZorp how it would evaluate a connection and print the answer.

    *args* holds six strings: protocol ("tcp" or "udp"), client address,
    client port, server address, server port and interface name.  Invalid
    values are reported through parser.error().  Unless *quiet* is set, the
    query is echoed before it is sent.

    Returns 0 when the query succeeded and 1 when talk() reported a
    non-zero result.  The result used to be discarded, so a failed query
    went unnoticed.
    """

    def parse_ip(parser, ip, description):
        """Return *ip* as an integer or report it through parser.error()."""
        try:
            return inet_aton(ip)
        except (ValueError, socket.error):
            # the module-level inet_aton() raises ValueError on bad input,
            # which the former "except socket.error" never caught
            parser.error("invalid %s ip: %s" % (description, ip))

    def parse_port(parser, port, description):
        """Return *port* as an integer in 1..65535 or report an error."""
        try:
            p = int(port)
            # 65535 is a valid port; it used to be rejected
            if 0 < p <= 65535:
                return p
            raise ValueError("")
        except ValueError:
            parser.error("invalid %s port: %s" % (description, port))

    def handle_reply(r):
        """Print the zones, service and dispatcher named in the reply."""
        attrs = r.get_nfmessage().get_attributes()
        client_zone = "not found"
        if attrs.has_key(KZA_QUERY_CLIENT_ZONE):
            client_zone = parse_name_attr(attrs[KZA_QUERY_CLIENT_ZONE])
        server_zone = "not found"
        if attrs.has_key(KZA_QUERY_SERVER_ZONE):
            server_zone = parse_name_attr(attrs[KZA_QUERY_SERVER_ZONE])
        service = "not found"
        if attrs.has_key(KZA_SVC_NAME):
            service = parse_name_attr(attrs[KZA_SVC_NAME])
        dispatcher = "not found"
        if attrs.has_key(KZA_DPT_NAME):
            dispatcher = parse_name_attr(attrs[KZA_DPT_NAME])
        print("Client zone: %s\nServer zone: %s\nService: %s\nDispatcher: %s" %
              (client_zone, server_zone, service, dispatcher))

    proto_name = args[0].lower()
    if proto_name == "tcp":
        proto = socket.IPPROTO_TCP
    elif proto_name == "udp":
        proto = socket.IPPROTO_UDP
    else:
        parser.error('protocol must be "tcp" or "udp"')

    saddr = parse_ip(parser, args[1], "client")
    sport = parse_port(parser, args[2], "client")
    daddr = parse_ip(parser, args[3], "server")
    dport = parse_port(parser, args[4], "server")

    if len(args[5]) > 16:
        parser.error('invalid interface name (>16 characters)')

    iface = args[5]

    if not quiet:
        print("evaluating %s:%d -> %s:%d on %s" %
              (inet_ntoa(saddr), sport, inet_ntoa(daddr), dport, iface))

    h = Handle()
    h.register_subsystem(Subsystem(NFNL_SUBSYS_KZORP))
    kzorp_m = h.create_message(NFNL_SUBSYS_KZORP, KZNL_MSG_QUERY, NLM_F_REQUEST | NLM_F_ACK)
    m = create_query_msg(proto, saddr, sport, daddr, dport, iface)
    kzorp_m.set_nfmessage(m)

    res = h.talk(kzorp_m, (0, 0), handle_reply)
    if res != 0:
        sys.stderr.write("Query failed: result='%s'\n" % (res,))
        return 1
    return 0

def main(args):
    """Command line entry point.

    *args* is the full argument vector including the program name, as in
    sys.argv; sys.argv is used when it is None.  The parameter used to be
    ignored in favour of sys.argv.

    Returns the result of the requested action.  When several actions are
    requested, all of them run and the first non-zero result is returned, so
    a later success no longer hides an earlier failure.  Returns 2 when not
    run as root or when KZorp refuses the connection, and 3 when no action
    was executed.
    """
    import errno  # function-scope import: only needed for ECONNREFUSED

    option_list = [
                     optparse.make_option("-z", "--zones",
                                          action="store_true", dest="zones",
                                          default=False,
                                          help="dump KZorp zones "
                                               "[default: %default]"),
                     optparse.make_option("-s", "--services",
                                          action="store_true", dest="services",
                                          default=False,
                                          help="dump KZorp services "
                                               "[default: %default]"),
                     optparse.make_option("-d", "--dispatchers",
                                          action="store_true", dest="dispatchers",
                                          default=False,
                                          help="dump KZorp dispatchers "
                                               "[default: %default]"),
                     optparse.make_option("-e", "--evaluate",
                                          dest="evaluate",
                                          type="string",
                                          nargs=6,
                                          default=None,
                                          help="evaluate "
                                               "arguments: <protocol> <client address> <client port> <server address> <server port> <interface name>"),
                     optparse.make_option("-q", "--quiet",
                                          action="store_true", dest="quiet",
                                          default=False,
                                          help="quiet operation "
                                               "[default: %default]"),
                     optparse.make_option("-u", "--upload",
                                          action="store", type="string", dest="upload",
                                          default=None,
                                          help="upload KZorp zone structure from file "
                                               "[default: %default]")
                  ]

    if args is None:
        args = sys.argv

    parser = optparse.OptionParser(option_list=option_list, prog="kzorp", usage = "usage: %prog [options]")
    # args[0] is the program name, which optparse must not see
    (options, args) = parser.parse_args(args[1:])

    if not (options.zones or options.services or options.dispatchers or
            options.upload is not None or options.evaluate is not None):
        parser.error("at least one option must be set")

    if os.getuid() != 0:
        sys.stderr.write("kzorp must be run as root\n")
        return 2

    results = []
    try:
        if options.zones:
            results.append(DumpZones(options.quiet).dump())
        if options.services:
            results.append(DumpServices(options.quiet).dump())
        if options.dispatchers:
            results.append(DumpDispatchers(options.quiet).dump())
        if options.upload:
            results.append(upload_zones(options.upload))
        if options.evaluate:
            results.append(evaluate(parser, options.evaluate, options.quiet))
    except socket.error, e:
        # a refused netlink connection is taken to mean that the kernel
        # side is not loaded
        if e.args and e.args[0] == errno.ECONNREFUSED:
            sys.stderr.write("KZorp support not present in kernel\n")
            return 2
        raise

    if not results:
        # e.g. "--upload ''": the option check passed but nothing ran
        return 3

    for res in results:
        if res:
            return res
    # every action reported success; hand back the last value unchanged
    return results[-1]

if __name__ == "__main__":
    # run the tool and use the value returned by main() as the exit status
    sys.exit(main(sys.argv))
