#!/usr/bin/python3
#
# Generate SSHFP DNS records (RFC4255) from knownhosts files or ssh-keyscan
#
# Copyright 2010 by Xelerance http://www.xelerance.com/ (Paul Wouters)
# Copyright 2012 Paul Wouters <pwouters@redhat.com>
# Copyright 2014 Gerald Turner <gturner@unzane.com>
# Copyright 2015 Jean-Michel Nirgal Vourgere <jmv_deb@nirgal.com>
# Copyright 2015 Dirk Stoecker <github@dstoecker.de>
# Copyright 2019 Kishan Takoordyal <kishan@cyberstorm.mu>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

# Program version; printed by show_version() and embedded in the header
# comment that main() writes for zone-transfer scans.
VERSION = "3.4"

import os
import sys
import subprocess
import optparse
import base64
import time
import hashlib
import hmac
import re
# www.dnspython.org
try:
	import dns.resolver
	import dns.query
	import dns.zone
except ImportError:
	print("sshfp requires the python-dns package from http://www.pythondns.org/", file=sys.stderr)
	print("Fedora: yum install python-dns", file=sys.stderr)
	print("Debian: apt-get install python-dnspython   (NOT python-dns!)", file=sys.stderr)
	print("openSUSE: zypper in python-dnspython", file=sys.stderr)
	sys.exit(1)

# Module-wide settings. A "global" statement at module level has no effect in
# Python; these lines only serve as a list of the names shared between
# functions. The values themselves are assigned in main().
# NOTE(review): the functions in this file actually use "algos" and "digests"
# (neither is listed here), while "algo", "khfile" and "hostnames" only appear
# as locals or parameters in the code visible here.
global all_hosts
global khfile
global hostnames
global trailing
global quiet
global port
global timeout
global algo

# known_hosts file read when -k/--knownhosts is not given ("~" is expanded
# by sshfp_from_file()).
DEFAULT_KNOWN_HOSTS_FILE = "~/.ssh/known_hosts"

def show_version():
	"""Write the program name and version to standard error."""
	banner = "sshfp version: %s" % VERSION
	print(banner, file=sys.stderr)

def create_sshfp(hostname, keytype, keyblob, digesttype):
	"""Build one SSHFP resource record line for a host key.

	hostname   -- host name, or dotted-quad IPv4 address, the key belongs to
	keytype    -- key type name as found in known_hosts / ssh-keyscan output
	              (e.g. "ssh-rsa", "ecdsa-sha2-nistp256", "ssh-ed25519")
	keyblob    -- base64 encoded public key
	digesttype -- "sha1" or "sha256"

	Returns "<owner> IN SSHFP <algorithm> <fptype> <FINGERPRINT>" with the
	fingerprint in upper-case hex. Returns "" when the key type or digest
	type is not supported, and "ERROR" (after a message on stderr) when the
	key blob is not valid base64.

	A dotted-quad hostname is rewritten to its in-addr.arpa. reverse name.
	Any other hostname gets a trailing dot when the global "trailing" is set.
	"""

	if keytype == "ssh-rsa":
		algorithm = "1"
	elif keytype == "ssh-dss":
		algorithm = "2"
	elif "ecdsa" in keytype:
		algorithm = "3"
	elif keytype == "ssh-ed25519":
		algorithm = "4"
	elif keytype == "ssh-xmss":
		algorithm = "5"
	else:
		return ""

	if digesttype == "sha1":
		digest, fptype = hashlib.sha1, "1"
	elif digesttype == "sha256":
		digest, fptype = hashlib.sha256, "2"
	else:
		return ""

	try:
		rawkey = base64.b64decode(keyblob)
	except (TypeError, ValueError):
		# On Python 3 malformed base64 raises binascii.Error, which is a
		# ValueError subclass; TypeError alone never caught it.
		print("FAILED on hostname "+hostname+" with keyblob "+keyblob, file=sys.stderr)
		return "ERROR"
	fp = digest(rawkey).hexdigest().upper()

	# Only a name made of exactly four all-digit labels is treated as an
	# IPv4 address and turned into a reverse (in-addr.arpa.) owner name.
	octets = hostname.split(".")
	reverse = len(octets) == 4 and all(octet.isdigit() for octet in octets)
	if reverse:
		hostname = ".".join(reversed(octets)) + ".in-addr.arpa."
	elif trailing and not hostname.endswith("."):
		# We don't know whether we need a trailing dot: if someone did
		# "ssh ns.foo" this may be "ns.foo." or "ns.foo" plus the
		# resolv.conf domain name, so it is only added on request.
		hostname = hostname + "."
	return hostname + " IN SSHFP " + algorithm + " " + fptype + " " + fp

def get_known_host_entry(known_hosts, host):
	"""Look up one host in a known_hosts file via "ssh-keygen -F".

	known_hosts -- path of the known_hosts file handed to ssh-keygen -f
	host        -- host name to search for

	Returns ssh-keygen's standard output with every line that starts with
	"#" removed. Anything ssh-keygen wrote to stderr is echoed to stderr
	unless the global "quiet" is set.
	"""
	keygen = subprocess.Popen(
		["ssh-keygen", "-f", known_hosts, "-F", host],
		stdout=subprocess.PIPE,
		stderr=subprocess.PIPE,
		universal_newlines=True)
	found, complaints = keygen.communicate()
	if not quiet and complaints:
		print(complaints, file=sys.stderr)
	return "\n".join(line for line in found.split("\n") if not line.startswith("#"))

def sshfp_from_file(khfile, wantedHosts):
	"""Generate SSHFP records from a known_hosts file.

	khfile      -- path of the known_hosts file ("~" is expanded)
	wantedHosts -- host names to produce records for

	The file counts as hashed when its first line starts with "|1|". A
	hashed file cannot be combined with the global "all_hosts" (the program
	exits with an error); otherwise each wanted host is looked up separately
	through ssh-keygen. A plain file is handed to process_records() as is.

	Returns the records as one newline-separated string ("" when nothing
	matched). Exits with status 1 when the file cannot be read.
	"""
	global all_hosts

	known_hosts = os.path.expanduser(khfile)
	try:
		# Read once and let the context manager close the file.
		with open(known_hosts) as khfp:
			data = khfp.read()
	except IOError:
		print("ERROR: failed to open file "+ known_hosts, file=sys.stderr)
		sys.exit(1)

	hashed_known_hosts = data.startswith("|1|")
	if not hashed_known_hosts:
		return process_records(data, wantedHosts)

	if all_hosts:
		print("ERROR: %s is hashed, cannot use with -a" % known_hosts, file=sys.stderr)
		sys.exit(1)

	# Hashed file: only the explicitly requested hosts can be looked up.
	fingerprints = []
	for host in wantedHosts:
		records = process_records(get_known_host_entry(known_hosts, host), (host,))
		# Skip hosts without a match so the output has no blank lines.
		if records:
			fingerprints.append(records)
	return "\n".join(fingerprints)

def check_keytype(keytype, hostname):
	"""Tell whether a key type is among the requested algorithms.

	keytype  -- key type name from a known_hosts / ssh-keyscan record
	hostname -- host the key belongs to (used in the diagnostic only)

	Returns True when any entry of the global "algos" occurs as a substring
	of keytype; the user-facing name "dsa" is matched as "dss", the name
	used in key files. Otherwise prints a note to stderr (unless the global
	"quiet" is set) and returns False.
	"""
	global algos
	file_names = ("dss" if wanted == "dsa" else wanted for wanted in algos)
	if any(name in keytype for name in file_names):
		return True
	if not quiet:
		print("Could only find key type %s for %s" % (keytype, hostname), file=sys.stderr)
	return False

def check_hashed(entry, hostnames):
	"""Find which of the given host names a hashed known_hosts entry matches.

	entry     -- hashed host field of the form "|1|<base64 salt>|<base64 hash>"
	hostnames -- candidate host names to test against the hash

	The hash is HMAC-SHA1 of the host name keyed with the salt. Returns the
	first matching host name, or "" when none matches, the hash type is not
	"1", or the entry is malformed (wrong number of fields, invalid base64).
	"""
	fields = entry[1:].split("|")
	if len(fields) != 3 or fields[0] != "1":
		return ""
	try:
		salt = base64.b64decode(fields[1])
		expected = base64.b64decode(fields[2])
	except ValueError:
		# Damaged entry: treat as "no match" instead of aborting the run.
		return ""
	for host in hostnames:
		candidate = hmac.new(salt, host.encode(encoding="utf-8"), hashlib.sha1).digest()
		if hmac.compare_digest(candidate, expected):
			return host
	return ""

def process_records(data, hostnames):
	"""Turn known_hosts style records into SSHFP record lines.

	data      -- text with one "<host> <keytype> <key> [comment]" record per
	             line, as found in known_hosts files or ssh-keyscan output
	hostnames -- host names to produce records for

	If the global "all_hosts" is true, SSHFP entries are produced for every
	record with an allowed key type. Otherwise only records whose host is in
	hostnames are used. For a comma separated host list only the first name
	is considered; hashed host fields are resolved against hostnames.

	Blank lines and lines starting with "#" or "@" are skipped. Records with
	fewer than three fields are reported on stderr (unless the global
	"quiet" is set) and skipped. One record is produced per entry of the
	global "digests".

	Returns the unique records, sorted, joined with newlines ("" if none).
	"""
	global digests
	unique_records = set()
	for line in data.split("\n"):
		record = line.strip()
		if not record or record.startswith(("#", "@")):
			continue
		# Split on any run of whitespace; an optional 4th column such as
		# "root@localhost" is ignored.
		fields = record.split()
		if len(fields) < 3:
			if not quiet:
				# Diagnostics go to stderr: stdout carries the records.
				print("Unable to read record '%s'" % record, file=sys.stderr)
			continue
		host, keytype, key = fields[:3]
		if "," in host:
			host = host.split(",")[0]
		elif host.startswith("|"):
			host = check_hashed(host, hostnames)
		if not ((all_hosts and host) or host in hostnames):
			continue
		if not check_keytype(keytype, host):
			continue
		for digesttype in digests:
			value = create_sshfp(host, keytype, key, digesttype)
			# create_sshfp() signals "unsupported" with "" and a bad key
			# blob with "ERROR"; neither is a record.
			if value and value != "ERROR":
				unique_records.add(value)
	return "\n".join(sorted(unique_records))

def get_record(domain, qtype):
	"""Resolve domain and return the first answer of the given type.

	domain -- name to look up
	qtype  -- "A" or "NS"

	Returns 0 when the resolver reports NXDOMAIN or NoAnswer. For "A" the
	first rdata object is returned, for "NS" the first target as a string.
	Any other qtype prints an error and exits with status 1 once an answer
	is available. Returns None if the answer set is empty.
	"""
	try:
		answers = dns.resolver.query(domain, qtype)
	except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
		return 0
	# Only the first entry is used; answers[0].target does not work.
	for first in answers:
		if qtype == "A":
			return first
		if qtype == "NS":
			return str(first.target)
		print("ERROR: error in get_record, unknown type " + qtype, file=sys.stderr)
		sys.exit(1)
	return None

def get_axfr_record(domain, nameserver):
	"""Fetch a zone by AXFR and return it as a dns.zone object.

	domain     -- zone to transfer
	nameserver -- server to request the transfer from

	A dns.exception.FormError raised during the transfer is replaced by a
	new FormError whose argument is the domain.
	"""
	try:
		transfer = dns.query.xfr(nameserver, domain, rdtype=dns.rdatatype.AXFR)
		return dns.zone.from_xfr(transfer)
	except dns.exception.FormError:
		raise dns.exception.FormError(domain)

def sshfp_from_axfr(domain, nameserver):
	"""Scan every host with an A record in a zone obtained by AXFR.

	domain     -- zone to transfer
	nameserver -- server to ask; when empty, the first NS record of the
	              domain is used

	If no NS record can be found the domain itself is scanned as a single
	host. Exits with status 1 when the domain contains a space or the
	transfer fails with a FormError.

	Returns the SSHFP records produced by sshfp_from_dns() for the hosts.
	"""
	if " " in domain:
		print("ERROR: space in domain '"+domain+"' can't be right, aborted", file=sys.stderr)
		sys.exit(1)
	if not nameserver:
		nameserver = get_record(domain, "NS")
		if not nameserver:
			print("WARNING: no NS record found for domain "+domain+". trying as host record instead", file=sys.stderr)
			# better than nothing
			return sshfp_from_dns([domain])
	try:
		axfr = get_axfr_record(domain, nameserver)
	except dns.exception.FormError as badDomain:
		print("AXFR error: %s - No permission or not authorative for %s; aborting" % (nameserver, badDomain), file=sys.stderr)
		sys.exit(1)

	# A domain given as "example.com." must not produce "example.com..".
	zone_name = domain.rstrip(".")
	hosts = []
	for (name, ttl, rdata) in axfr.iterate_rdatas('A'):
		label = str(name)
		if "@" in label:
			# zone apex
			fqdn = zone_name + "."
		elif label != "localhost":
			fqdn = "%s.%s." % (label, zone_name)
		else:
			continue
		# A name with several A records only needs to be scanned once.
		if fqdn not in hosts:
			hosts.append(fqdn)
	return sshfp_from_dns(hosts)

def sshfp_from_dns(hosts):
	"""Scan hosts with ssh-keyscan and return their SSHFP records.

	hosts -- list of host names to scan

	Uses the globals "port", "timeout" and "algos" for the ssh-keyscan
	call. The scanner's stderr, with "#" comment lines stripped, is echoed
	to stderr unless the global "quiet" is set. The scanner's stdout is
	passed to process_records() together with hosts.
	"""
	global quiet, port, timeout, algos

	if "dsa" in algos:
		print("WARNING: openssh has obsoleted dsa/dss keys", file=sys.stderr)

	scan_options = ["ssh-keyscan", "-p", str(port), "-T", str(timeout), "-t", ",".join(algos)]
	scanner = subprocess.Popen(
		scan_options + hosts,
		stdout=subprocess.PIPE,
		stderr=subprocess.PIPE,
		universal_newlines=True)
	scanned, diagnostics = scanner.communicate()
	# strip scan comments
	diagnostics = re.sub("#.*\n", "", diagnostics)
	if not quiet:
		print(diagnostics, file=sys.stderr)
	return process_records(scanned, hosts)
	
def _supported_algos(requested, quiet):
	"""Return the key types from requested that the local ssh-keyscan accepts.

	Each type is probed with "ssh-keyscan -t <type>"; a type is dropped (with
	a warning on stderr unless quiet) when ssh-keyscan reports
	"Unknown key type". A new list is built so that no list is modified
	while it is being iterated.
	"""
	supported = []
	for algo in requested:
		cmd = ["ssh-keyscan", "-t", algo]
		process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
		(stdout, stderr) = process.communicate()
		if "Unknown key type" in stderr:
			if not quiet:
				print("WARNING: key type %s not supported" % algo, file=sys.stderr)
		else:
			supported.append(algo)
	return supported

def main():
	"""Command-line entry point.

	Parses the options, stores the shared settings in the module globals
	(all_hosts, trailing, quiet, port, timeout, algos, digests) and then
	produces SSHFP records from one of three sources:

	  -s -a   zone transfer of each domain argument, then scan of its hosts
	  -s      scan of the hosts given as arguments
	  (else)  the known_hosts file

	The result is written to the -o file, or printed to stdout.
	"""
	global all_hosts
	global trailing
	global quiet
	global port
	global timeout
	global algos
	global digests

	parser = optparse.OptionParser()
	parser.add_option("-k", "--knownhosts", "--known-hosts",
			action="store",
			dest="known_hosts",
			metavar="KNOWN_HOSTS_FILE",
			help="obtain public ssh keys from the known_hosts file KNOWN_HOSTS_FILE")
	parser.add_option("-s", "--scan",
			action="store_true",
			default=False,
			dest="scan",
			help="scan the listed hosts for public keys using ssh-keyscan")
	parser.add_option("-a", "--all",
			action="store_true",
			dest="all_hosts",
			help="scan all hosts in the known_hosts file when used with -k. When used with -s, attempt a zone transfer to obtain all A records in the domain provided")
	parser.add_option("-d", "--trailing-dot",
			action="store_true",
			dest="trailing_dot",
			help="add a trailing dot to the hostname in the SSHFP records")
	parser.add_option("-o", "--output",
			action="store",
			metavar="FILENAME",
			default=None,
			help="write to FILENAME instead of stdout")
	parser.add_option("-p", "--port",
			action="store",
			metavar="PORT",
			type="int",
			default=22,
			help="use PORT for scanning")
	parser.add_option("-v", "--version",
			action="store_true",
			dest="version",
			help="print version information and exit")
	parser.add_option("-q", "--quiet",
			action="store_true",
			dest="quiet")
	parser.add_option("-T", "--timeout",
			action="store",
			type="int",
			dest="timeout",
			default=5,
			help="scanning timeout (default %default)")
	parser.add_option("-t", "--type",
			action="append",
			type="choice",
			dest="algo",
			choices=["rsa", "ecdsa", "ed25519", "dsa", "xmss"],
			default=[],
			help="key type to fetch (may be specified more than once, default rsa,ecdsa,ed25519,xmss)")
	parser.add_option("--digest",
			action="append",
			type="choice",
			dest="digest",
			choices=["sha1", "sha256"],
			default=[],
			help="fingerprint hash function (may be specified more than once, default sha256)")
	parser.add_option("-n", "--nameserver",
			action="store",
			type="string",
			dest="nameserver",
			default="",
			help="nameserver to use for AXFR (only valid with -s -a)")
	(options, args) = parser.parse_args()

	# Answer --version before spawning ssh-keyscan for the key type probe.
	if options.version:
		show_version()
		sys.exit(0)

	# parse options
	khfile = options.known_hosts or DEFAULT_KNOWN_HOSTS_FILE
	output = options.output
	quiet = options.quiet
	trailing = options.trailing_dot
	timeout = options.timeout
	port = options.port
	digests = options.digest or ["sha256"]
	all_hosts = options.all_hosts
	algos = _supported_algos(options.algo or ["rsa", "ecdsa", "ed25519", "xmss"], quiet)

	if not quiet and port != 22:
		print("WARNING: non-standard port numbers are not designated in SSHFP records", file=sys.stderr)
	if not quiet and options.known_hosts and options.scan:
		print("WARNING: Ignoring known hosts option -k, -s was passed", file=sys.stderr)
	# -n is only meaningful for a zone transfer, which needs both -s and -a.
	if options.nameserver and not (options.scan and options.all_hosts):
		print("ERROR: Cannot specify -n without -s and -a", file=sys.stderr)
		sys.exit(1)
	if not options.scan and options.all_hosts and args:
		print("WARNING: -a and hosts both passed, ignoring manual host list", file=sys.stderr)
	if not args and (not all_hosts):
		print("WARNING: Assuming -a", file=sys.stderr)
		all_hosts = True

	if options.scan and options.all_hosts:
		datal = []
		for arg in args:
			datal.append(sshfp_from_axfr(arg, options.nameserver))
		if not quiet:
			# Name the -n server if one was given, else the domains transferred.
			source = options.nameserver or ", ".join(args)
			datal.insert(0, ";")
			datal.insert(0, "; Generated by sshfp %s from %s at %s" % (VERSION, source, time.ctime()))
			datal.insert(0, ";")
		data = "\n".join(datal)
	elif options.scan:	# Scan specified hosts
		if not args:
			print("ERROR: You asked me to scan, but didn't give any hosts to scan", file=sys.stderr)
			sys.exit(1)
		data = sshfp_from_dns(args)
	else: # known_hosts
		data = sshfp_from_file(khfile, args)

	if output:
		try:
			fp = open(output, "w")
		except IOError:
			print("ERROR: can't open '%s' for writing" % output, file=sys.stderr)
			sys.exit(1)
		# The context manager closes the file even if the write fails.
		with fp:
			fp.write(data)
	else:
		print(data)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
	main()
