#!/usr/local/bin/python3.6 -u
# Copyright (c) 2012, Jakob Borg
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#    * Redistributions of source code must retain the above copyright
#      notice, this list of conditions and the following disclaimer.
#    * Redistributions in binary form must reproduce the above copyright
#      notice, this list of conditions and the following disclaimer in the
#      documentation and/or other materials provided with the distribution.
#    * The name of the author may not be used to endorse or promote products
#      derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY JAKOB BORG ''AS IS'' AND ANY EXPRESS OR IMPLIED
# WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
# EVENT SHALL JAKOB BORG BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
# IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
# OF SUCH DAMAGE.
# Initial code taken from: https://github.com/jm66/solaris-extra-snmp

import errno
import json
import os
import re
import socket
import sys
import time
import subprocess
import libzfs
import syslog
import snmp_passpersist as snmp
import django

# The FreeNAS Django webapp lives outside the default sys.path; point Django
# at its settings module and extend the path BEFORE importing freenasUI code.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'freenasUI.settings')
sys.path.append("/usr/local/www")

# Initialise the Django app registry so model imports below succeed.
django.setup()

from freenasUI.storage.models import Replication
from freenasUI.tools.arc_summary import get_Kstat, get_arc_efficiency




POOLING_INTERVAL = 10  # Update timer, in seconds, between update_data() runs
MAX_RETRY = 10  # Number of successive retries allowed before giving up
OID_BASE = '.1.3.6.1.4.1.25359.1'  # Enterprise OID subtree served by this agent
pp = None  # snmp_passpersist.PassPersist instance, set up in main()
# ARC efficiency snapshot taken once at import time (ratios only; counters
# are re-read live via kstat() on every update).
ARC = get_arc_efficiency(get_Kstat())
# Unix socket of the freenas-snmpd helper daemon (zpool I/O and ZIL stats).
FREENASSNMPDSOCK = '/var/run/freenas-snmpd.sock'
# Multipliers for human-readable size suffixes, used by unprettyprint().
size_dict = {
	"K": 1024,
	"M": 1048576,
	"G": 1073741824,
	"T": 1099511627776
}


def get_from_freenas_snmpd_sock(val_to_obtain):
	"""
	Query the freenas-snmpd daemon over its Unix socket.

	Sends `val_to_obtain` (bytes, e.g. b"get_all") and reads the reply
	until EOF, then decodes it as JSON.

	Returns the decoded object, or {} if the daemon is unreachable or the
	reply is not valid UTF-8 JSON (best effort by design).
	"""
	data = b''
	try:
		# Using the socket as a context manager guarantees it is closed even
		# on error.  The original code created the socket inside the try and
		# closed it in a finally block, which raised NameError when
		# socket.socket() itself failed.
		with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
			s.connect(FREENASSNMPDSOCK)
			s.sendall(val_to_obtain)
			while True:
				chunk = s.recv(4096)
				if not chunk:
					break
				data += chunk
	except socket.error:
		# Daemon may not be running; fall through and return {} below.
		pass
	try:
		return json.loads(data.decode('utf8'))
	except (ValueError, UnicodeDecodeError):
		return {}



def update_data():
	"""
	Refresh every OID exposed by this pass-persist agent.

	Called by snmp_passpersist every POOLING_INTERVAL seconds.  Subtrees
	(relative to OID_BASE):
	  1.x - zpool name/size/usage and I/O counters
	  2.x - ARC statistics
	  3.x - L2ARC statistics
	  5.x - zvol size/usage
	  6.x - ZIL operation counters
	  7.x - dataset size/usage
	  8.x - replication task status
	"""
	global pp

	# zpool I/O counters: totals since boot ("all") and over the last second.
	all_iostat = zpoolio_interval("all")
	onesec_iostat = zpoolio_interval(1)
	zfs = libzfs.ZFS()
	pools = [pool for pool in zfs.pools]
	datasets = []
	zvols = []
	# Only enabled replication tasks are reported under 8.x.
	qs = Replication.objects.filter(repl_enabled=True)
	pp.add_gau('2.1.0', zfs_arc_size())
	pp.add_gau('2.2.0', zfs_arc_meta())
	pp.add_gau('2.3.0', zfs_arc_data())
	pp.add_cnt_32bit('2.4.0', zfs_arc_hits())
	pp.add_cnt_32bit('2.5.0', zfs_arc_misses())
	pp.add_gau('2.6.0', zfs_arc_c())
	pp.add_gau('2.7.0', zfs_arc_p())
	pp.add_str('2.8.0', zfs_arc_miss_percent())
	pp.add_str('2.9.0', zfs_arc_cache_hit_ratio())
	pp.add_str('2.10.0', zfs_arc_cache_miss_ratio())
	pp.add_cnt_32bit('3.1.0', zfs_l2arc_hits())
	pp.add_cnt_32bit('3.2.0', zfs_l2arc_misses())
	pp.add_cnt_32bit('3.3.0', zfs_l2arc_read())
	pp.add_cnt_32bit('3.4.0', zfs_l2arc_write())
	pp.add_gau('3.5.0', zfs_l2arc_size())
	pp.add_cnt_64bit('6.1.0', zfs_zilstat_ops1())
	pp.add_cnt_64bit('6.2.0', zfs_zilstat_ops5())
	pp.add_cnt_64bit('6.3.0', zfs_zilstat_ops10())

	for res in map(zfs_segregate, [pool.root_dataset for pool in pools]):
		# Excluding the first item in the datasets list as it's always the
		# root_dataset and we already have the stats on that (i.e. the pool!)
		datasets.extend(res[0][1:])
		zvols.extend(res[1])

	for i, zpool in enumerate(pools):
		stri = str(i + 1)
		pool = zpool.name
		pool_health = zpool.properties['health'].value
		# Dividing by 1024 to get to KB
		pool_used = unprettyprint(zpool.root_dataset.properties['used'].value) / 1024
		pool_available = unprettyprint(zpool.root_dataset.properties['available'].value) / 1024
		pool_size = unprettyprint(zpool.properties['size'].value) / 1024

		pp.add_str('1.1.' + stri, pool)
		pp.add_cnt_64bit('1.2.' + stri, pool_available)
		pp.add_cnt_64bit('1.3.' + stri, pool_used)
		pp.add_str('1.4.' + stri, pool_health)
		pp.add_cnt_64bit('1.5.' + stri, pool_size)
		# 1.12-1.14 repeat the sizes as gauges in MB (KB / 1024).
		pp.add_gau('1.12.' + stri, pool_available / 1024)
		pp.add_gau('1.13.' + stri, pool_used / 1024)
		pp.add_gau('1.14.' + stri, pool_size / 1024)
		pp.add_cnt_64bit('1.15.' + stri, all_iostat[pool]['opread'])
		pp.add_cnt_64bit('1.16.' + stri, all_iostat[pool]['opwrite'])
		pp.add_cnt_64bit('1.17.' + stri, all_iostat[pool]['bwread'])
		pp.add_cnt_64bit('1.18.' + stri, all_iostat[pool]['bwrite'])
		# One-second stats may be missing a pool just after startup, hence .get
		pp.add_cnt_64bit('1.19.' + stri, onesec_iostat[pool].get('opread', 0))
		pp.add_cnt_64bit('1.20.' + stri, onesec_iostat[pool].get('opwrite', 0))
		pp.add_cnt_64bit('1.21.' + stri, onesec_iostat[pool].get('bwread', 0))
		pp.add_cnt_64bit('1.22.' + stri, onesec_iostat[pool].get('bwrite', 0))

	for i, zvol in enumerate(zvols):
		stri = str(i + 1)
		# Sizes converted from bytes to KB.
		volsize = unprettyprint(zvol.properties['volsize'].value) / 1024
		vol_used = unprettyprint(zvol.properties['used'].value) / 1024
		vol_available = unprettyprint(zvol.properties['available'].value) / 1024
		pp.add_str('5.1.' + stri, zvol.name)
		pp.add_cnt_64bit('5.2.' + stri, vol_available)
		pp.add_cnt_64bit('5.3.' + stri, vol_used)
		pp.add_cnt_64bit('5.4.' + stri, volsize)
		pp.add_gau('5.12.' + stri, vol_available / 1024)
		pp.add_gau('5.13.' + stri, vol_used / 1024)
		pp.add_gau('5.14.' + stri, volsize / 1024)

	for i, ds in enumerate(datasets):
		stri = str(i + 1)
		ds_used = unprettyprint(ds.properties['used'].value) / 1024
		ds_available = unprettyprint(ds.properties['available'].value) / 1024
		# Datasets have no 'size' property; derive it from used + available.
		ds_size = ds_used + ds_available
		pp.add_str('7.1.' + stri, ds.name)
		pp.add_cnt_64bit('7.2.' + stri, ds_available)
		pp.add_cnt_64bit('7.3.' + stri, ds_used)
		pp.add_cnt_64bit('7.4.' + stri, ds_size)
		pp.add_gau('7.12.' + stri, ds_available / 1024)
		pp.add_gau('7.13.' + stri, ds_used / 1024)
		pp.add_gau('7.14.' + stri, ds_size / 1024)

	for i, repl in enumerate(qs):
		stri = str(i + 1)
		name = repl.repl_filesystem
		status = repl.status
		enabled = repl.repl_enabled
		remote_addr = repl.repl_remote
		remote_zfs = repl.repl_zfs
		last = repl.repl_lastsnapshot
		pp.add_str('8.1.' + stri, name)
		pp.add_str('8.2.' + stri, '%s@%s' % (remote_zfs, remote_addr))
		pp.add_str('8.3.' + stri, enabled)
		pp.add_str('8.4.' + stri, status)
		pp.add_str('8.5.' + stri, last)


class ArgumentValidationError(ValueError):
	"""
	Raised when a positional argument passed to a function does not have
	the type that function requires.
	"""

	def __init__(self, arg_num, func_name, accepted_arg_type):
		# Pre-render the message so __str__ is a simple attribute read.
		self.error = f'The {arg_num} argument of {func_name}() is not a {accepted_arg_type}'

	def __str__(self):
		return self.error


def unprettyprint(ster):
	"""
	Convert a human-readable size string to an integer byte count.

	Plain numbers pass straight through ("512" -> 512); a trailing
	K/M/G/T suffix multiplies by the matching power of 1024
	("1K" -> 1024, "1.5K" -> 1536).

	Returns 0 when the value cannot be parsed at all (instead of raising).
	"""
	# Same multipliers as the module-level size_dict, kept local so the
	# function is self-contained.
	suffixes = {"K": 1024, "M": 1024 ** 2, "G": 1024 ** 3, "T": 1024 ** 4}
	try:
		return int(float(ster))
	except (TypeError, ValueError):
		pass
	try:
		return int(float(ster[:-1]) * suffixes[ster[-1]])
	except (TypeError, ValueError, KeyError, IndexError):
		# Narrowed from a bare except: unknown suffix, empty or non-string
		# input all fall through to 0.
		return 0


def kstat(name):
	"""
	Read the sysctl OID ``kstat.<name>`` and return its integer value.

	Expected sysctl output is "kstat.<name>: <value>", so the second
	whitespace-separated field is the value.  Returns 0 if the command
	fails or the output cannot be parsed.
	"""
	output = subprocess.getoutput("sysctl kstat." + name)
	try:
		# Raw string for the regex (bare "\s" is a DeprecationWarning on
		# modern Python); narrowed from a bare except.
		return int(re.split(r"\s+", output)[1])
	except (ValueError, IndexError):
		return 0


def zfs_arc_size():
	"""Total ARC size, converted from bytes to KB."""
	arc_bytes = kstat("zfs.misc.arcstats.size")
	return arc_bytes / 1024


def zfs_arc_data():
	"""ARC data size, converted from bytes to KB."""
	data_bytes = kstat("zfs.misc.arcstats.data_size")
	return data_bytes / 1024


def zfs_arc_meta():
	"""ARC metadata usage, converted from bytes to KB."""
	meta_bytes = kstat("zfs.misc.arcstats.arc_meta_used")
	return meta_bytes / 1024


def zfs_arc_hits():
	"""ARC hit count, wrapped to fit a 32-bit SNMP counter."""
	hits = kstat("zfs.misc.arcstats.hits")
	return hits % (2 ** 32)


def zfs_arc_misses():
	"""ARC miss count, wrapped to fit a 32-bit SNMP counter."""
	misses = kstat("zfs.misc.arcstats.misses")
	return misses % (2 ** 32)


def zfs_arc_miss_percent():
	"""
	ARC miss percentage as a string (SNMP has no float type).

	Returns "0" when no reads have been recorded yet.
	"""
	hits = kstat("zfs.misc.arcstats.hits")
	misses = kstat("zfs.misc.arcstats.misses")
	total_reads = hits + misses
	if total_reads <= 0:
		return "0"
	return str(100 - float(100 * hits / total_reads))


def zfs_arc_c():
	"""arcstats.c value (ARC target size), converted from bytes to KB."""
	c_bytes = kstat("zfs.misc.arcstats.c")
	return c_bytes / 1024


def zfs_arc_p():
	"""arcstats.p value, converted from bytes to KB."""
	p_bytes = kstat("zfs.misc.arcstats.p")
	return p_bytes / 1024


def zfs_arc_cache_hit_ratio():
	"""
	ARC cache hit ratio from the import-time arc_summary snapshot,
	as a percentage string with the trailing '%' stripped.
	"""
	ratio = ARC['cache_hit_ratio']['per']
	return ratio[:-1]


def zfs_arc_cache_miss_ratio():
	"""
	ARC cache miss ratio from the import-time arc_summary snapshot,
	as a percentage string with the trailing '%' stripped.
	"""
	ratio = ARC['cache_miss_ratio']['per']
	return ratio[:-1]


def zilstatd_ops(interval):
	"""
	ZIL operation count over `interval` seconds, as reported by the
	freenas-snmpd daemon.  Returns 0 when the data is unavailable.
	"""
	snmp_data = get_from_freenas_snmpd_sock(b"get_all")
	try:
		return snmp_data["zil_data"][str(interval)]['ops']
	except KeyError:
		return 0


# Note: Currently only 1 second interval and "all" is supported
# to add more make the appropriate worker in gui/tools/freenas-snmpd.py
def zpoolio_interval(interval):
	"""
	Per-pool I/O statistics over `interval` (an int, or "all" for totals),
	as reported by freenas-snmpd.  Returns {} when the data is unavailable.
	"""
	snmp_data = get_from_freenas_snmpd_sock(b"get_all")
	try:
		return snmp_data["zpool_data"][str(interval)]
	except KeyError:
		return {}


def zfs_zilstat_ops1():
	"""ZIL operations during the last 1-second interval."""
	return zilstatd_ops(interval=1)


def zfs_zilstat_ops5():
	"""ZIL operations during the last 5-second interval."""
	return zilstatd_ops(interval=5)


def zfs_zilstat_ops10():
	"""ZIL operations during the last 10-second interval."""
	return zilstatd_ops(interval=10)


def zfs_l2arc_size():
	"""L2ARC size, converted from bytes to KB."""
	l2_bytes = kstat("zfs.misc.arcstats.l2_size")
	return l2_bytes / 1024


def zfs_l2arc_hits():
	"""L2ARC hit count, wrapped to fit a 32-bit SNMP counter."""
	hits = kstat("zfs.misc.arcstats.l2_hits")
	return hits % (2 ** 32)


def zfs_l2arc_misses():
	"""L2ARC miss count, wrapped to fit a 32-bit SNMP counter."""
	misses = kstat("zfs.misc.arcstats.l2_misses")
	return misses % (2 ** 32)


def zfs_l2arc_write():
	"""L2ARC bytes written, as a 32-bit KB counter."""
	written_kb = kstat("zfs.misc.arcstats.l2_write_bytes") / 1024
	return written_kb % (2 ** 32)


def zfs_l2arc_read():
	"""L2ARC bytes read, as a 32-bit KB counter."""
	read_kb = kstat("zfs.misc.arcstats.l2_read_bytes") / 1024
	return read_kb % (2 ** 32)


def zfs_segregate(zfs_dataset):
	"""
	Recursively partition `zfs_dataset` and all of its descendants into
	regular datasets and zvols.

	Feed this a pool's root dataset to obtain every dataset and zvol in
	the pool.

	:param zfs_dataset: a libzfs.ZFSDataset (exact type required).
	:returns: a (datasets, zvols) tuple, each a list of libzfs.ZFSDataset.
	:raises ArgumentValidationError: if the argument is not a ZFSDataset.
	"""
	# Exact-type check kept on purpose (matches the documented contract).
	if type(zfs_dataset) is not libzfs.ZFSDataset:
		raise ArgumentValidationError(1, 'zfs_segregate', libzfs.ZFSDataset)
	if zfs_dataset.properties['type'].value == 'volume':
		# zvols cannot have children, so stop recursing here.
		return [], [zfs_dataset]
	datasets = [zfs_dataset]
	zvols = []
	# `children` is already iterable -- no need to materialize it in a list
	# before mapping over it (avoids a throwaway copy per recursion level).
	for child_datasets, child_zvols in map(zfs_segregate, zfs_dataset.children):
		datasets.extend(child_datasets)
		zvols.extend(child_zvols)
	return datasets, zvols


def main():
	"""
	Entry point: run the SNMP pass-persist loop, restarting on errors.

	Allows up to MAX_RETRY failures; the retry budget resets when the
	previous failure is more than one hour old.  Exits 0 on Ctrl-C or
	when snmpd closes the pipe, 1 when the retry budget is exhausted.
	"""
	global pp
	retry_timestamp = int(time.time())
	retry_counter = MAX_RETRY
	while retry_counter > 0:
		try:
			syslog.syslog(syslog.LOG_WARNING, "Starting FreeNAS monitoring...")

			# Load helpers
			pp = snmp.PassPersist(OID_BASE)
			# Blocks here, serving SNMP requests, until an error occurs.
			pp.start(update_data, POOLING_INTERVAL)
		except KeyboardInterrupt:
			print("Exiting on user request.")
			sys.exit(0)
		except IOError as e:
			if e.errno == errno.EPIPE:
				# snmpd closed our pipe: normal shutdown, not a failure.
				syslog.syslog(syslog.LOG_INFO, "Snmpd had close pipe, exiting...")
				sys.exit(0)
			else:
				syslog.syslog(syslog.LOG_WARNING, "Updater thread has died: IOError: %s" % e)
		except Exception as e:
			syslog.syslog(syslog.LOG_WARNING, "Main thread has died: %s: %s" % (e.__class__.__name__, e))
		else:
			# pp.start() returned without raising: the updater thread died.
			syslog.syslog(syslog.LOG_WARNING, "Updater thread has died: %s" % pp.error)

		syslog.syslog(syslog.LOG_WARNING, "Restarting monitoring in 15 sec...")
		time.sleep(15)

		# Errors frequency detection
		now = int(time.time())
		if (now - 3600) > retry_timestamp:  # If the previous error is older than 1H
			retry_counter = MAX_RETRY  # Reset the counter
		else:
			retry_counter -= 1  # Else countdown
		retry_timestamp = now
	syslog.syslog(syslog.LOG_ERR, "Too many retries, aborting!")
	sys.exit(1)


# Run the monitoring loop when executed directly (as an snmpd pass_persist
# handler); importing this module does not start it.
if __name__ == "__main__":
	main()
