#!/usr/bin/env python
# vim:se fileencoding=utf-8 :
# Fix EDID screen dimensions
# (c) 2018 Michał Górny
# Released under the terms of 2-clause BSD license

import argparse
import sys


def verify_checksum(data):
	"""
	Check that an EDID block carries a valid checksum -- that is,
	that all of its bytes add up to a multiple of 0x100.

	@data is the block to check.  It should be a bytes object
	or a compatible type (anything that iterates as integers).

	Returns nothing on success.  Raises ValueError, reporting
	the non-zero remainder, when the checksum does not match.
	"""

	remainder = sum(data) % 0x100
	if remainder == 0:
		return

	message = 'EDID checksum mismatch: remainder is {}'.format(remainder)
	raise ValueError(message)


def update_checksum(data):
	"""
	Update the checksum at the last byte of data, so that the sum
	of all bytes in the block becomes zero mod 0x100.

	@data specifies the data block to checksum.  It is modified
	in place, and should be a memoryview over bytearray.  The previous
	value of the last byte is ignored.
	"""

	# Negate first, then reduce: when the payload already sums to
	# a multiple of 0x100 the checksum has to be 0, not 0x100 (which
	# does not fit in a byte and would make the assignment fail).
	data[-1] = -sum(data[:-1]) % 0x100


def process_edid(f, visitor):
	"""
	Process the EDID data in file @f.  Does minimal verification
	of the EDID file, and calls methods of @visitor for specific parts
	of it.  The visitor methods are passed memoryview over bytearrays
	of EDID data and are allowed to modify it in place.

	@f is a binary file object positioned at the start of the EDID
	data.  @visitor must provide three methods:

	- edid_structure(data) -- called once with the 128-byte base block,
	- descriptor(desc) -- called with each of the four 18-byte
	  descriptor slots of the base block, and with every detailed
	  timing descriptor found in a CEA extension block,
	- post_block(data) -- called with every 128-byte block after its
	  descriptors were visited.

	See GetVisitor and SetVisitor for which methods are used.

	Returns bytes object with complete file data, potentially modified
	by visitors.

	Raises ValueError if the data is truncated, lacks the EDID magic
	bytes, has an unsupported version, fails checksum verification
	or is followed by trailing data.  Raises NotImplementedError
	for extension blocks other than CEA.
	"""

	block_len = 128
	desc_len = 18

	outdata = b''

	# read the initial EDID structure
	data = memoryview(bytearray(f.read(block_len)))
	if len(data) < block_len:
		raise ValueError('Truncated EDID structure')
	if data[:8] != b'\x00\xff\xff\xff\xff\xff\xff\x00':
		raise ValueError('EDID magic bytes not found')
	if data[18] != 1:
		raise ValueError('Unsupported EDID version: {}'.format(data[18]))
	verify_checksum(data)
	visitor.edid_structure(data)

	# visit descriptor blocks
	for desc in (data[54:72], data[72:90], data[90:108], data[108:126]):
		visitor.descriptor(desc)

	# let the visitor finish the block (e.g. fix up the checksum)
	# before it is appended to the output
	visitor.post_block(data)
	outdata += data

	# read extensions; byte 126 of the base block holds their count
	numext = data[126]
	while numext > 0:
		# the first byte of an extension block is its tag
		extdata = f.read(1)
		if len(extdata) == 0:
			raise ValueError('Expected extension block missing')
		if extdata == b'\x02':
			# CEA EDID
			print('CEA EDID found')
			extdata = memoryview(bytearray(extdata + f.read(block_len - 1)))
			if len(extdata) < block_len:
				raise ValueError('Truncated CEA EDID extension')
			verify_checksum(extdata)

			# Byte 2 is the offset of the first detailed timing
			# descriptor (0 means there are none).  The last byte
			# of the block is the checksum, so only whole 18-byte
			# slots ending before it are visited; otherwise a block
			# filled with descriptors up to the checksum (or one
			# with short, non-zero padding) would hand the visitor
			# a truncated slice.
			dtd_start = extdata[2]
			last_byte = block_len - 1
			while dtd_start != 0 and dtd_start + desc_len <= last_byte:
				desc = extdata[dtd_start:dtd_start + desc_len]
				if desc[:2] == b'\x00\x00':
					# zero padding -- no more descriptors
					break
				visitor.descriptor(desc)
				dtd_start += desc_len

			visitor.post_block(extdata)
			outdata += extdata
		else:
			raise NotImplementedError('Unsupported extension: {}'.format(extdata[0]))

		numext -= 1

	if f.read(1):
		raise ValueError('Data past last extension found')

	return outdata


class GetVisitor(object):
	"""
	Read-only visitor: reports the screen dimensions stored in EDID
	data on standard output and never modifies the data.
	"""

	def edid_structure(self, data):
		"""
		Handle the initial EDID structure @data.  Prints the rough
		screen size kept in bytes 21 and 22 (stored in cm).
		"""

		width_cm, height_cm = data[21:23]
		report = 'EDID structure: {} cm x {} cm'.format(width_cm, height_cm)
		print(report)

	def descriptor(self, data):
		"""
		Handle the 18-byte descriptor @data.  When it is a detailed
		timing descriptor, prints the image size kept in it (stored
		in mm); other descriptors are silently skipped.
		"""

		is_timing = data[:2] != b'\x00\x00'
		if is_timing:
			# byte 14 packs the upper four bits of both values:
			# width in the high nibble, height in the low one
			packed_msb = data[14]
			width_mm = ((packed_msb & 0xf0) << 4) | data[12]
			height_mm = ((packed_msb & 0x0f) << 8) | data[13]
			print('Detailed timing desc: {} mm x {} mm'.format(
				width_mm, height_mm))

	def post_block(self, data):
		"""
		Called once a whole block @data has been visited.  Nothing
		to do when only reading.
		"""

		return None


class SetVisitor(object):
	"""
	Modifying visitor: overwrites the screen dimensions stored in EDID
	data with new ones and reports every change on standard output.
	"""

	def __init__(self, new_dims):
		"""
		Set up the visitor.

		@new_dims carries the new dimensions (in mm) as its width
		and height attributes, in DimString format.
		"""

		self.new_dims = new_dims

		# Detailed timing descriptors keep each dimension as 12 bits:
		# a byte with the low eight bits of each value, plus one
		# shared byte with the top four bits of width (high nibble)
		# and of height (low nibble).  Precompute all three bytes.
		width_top = (new_dims.width & 0xf00) >> 4
		height_top = (new_dims.height & 0xf00) >> 8
		self.w_lsb = new_dims.width & 0xff
		self.h_lsb = new_dims.height & 0xff
		self.wh_msb = width_top | height_top

	def edid_structure(self, data):
		"""
		Handle the initial EDID structure @data.  Overwrites the rough
		screen size in bytes 21 and 22 with the new dimensions
		rounded to 1 cm.
		"""

		dims = self.new_dims
		data[21] = round(dims.width / 10)
		data[22] = round(dims.height / 10)
		# report the values as actually stored in the block
		stored_w, stored_h = data[21:23]
		print('EDID structure updated to {} cm x {} cm'.format(
			stored_w, stored_h))

	def descriptor(self, data):
		"""
		Handle the 18-byte descriptor @data.  When it is a detailed
		timing descriptor, overwrites the image size kept in it;
		other descriptors are left untouched.
		"""

		is_timing = data[:2] != b'\x00\x00'
		if not is_timing:
			return

		data[12] = self.w_lsb
		data[13] = self.h_lsb
		data[14] = self.wh_msb
		dims = self.new_dims
		print('Detailed timing desc updated to {} mm x {} mm'.format(
			dims.width, dims.height))

	def post_block(self, data):
		"""
		Called once a whole block @data has been visited.  Recomputes
		the block checksum to account for the changes made.
		"""

		update_checksum(data)


class DimString(object):
	"""
	Command-line argument type holding screen dimensions given
	as a WIDTHxHEIGHT string.  The parsed integers are available
	as the width and height attributes.
	"""

	def __init__(self, val):
		"""
		Parse the string @val.  Raises ValueError unless it consists
		of two integers separated by 'x'.
		"""
		parts = val.split('x', 1)
		width, height = map(int, parts)
		self.width = width
		self.height = height

	def __repr__(self):
		template = 'DimString("{}x{}")'
		return template.format(self.width, self.height)


def main(prog_name, *argv):
	"""
	Entry point of the program: parses the command line and performs
	the requested action on the EDID file.

	@prog_name should be argv[0].  The remaining arguments should be
	program arguments argv[1:].

	Returns numeric exit status (or raises an exception).
	"""

	parser = argparse.ArgumentParser(prog=prog_name)
	actions = parser.add_mutually_exclusive_group()
	actions.add_argument(
		'-g', '--get',
		action='store_true',
		help='Get screen dimensions from EDID')
	actions.add_argument(
		'-s', '--set',
		type=DimString,
		metavar='WIDTHxHEIGHT',
		help='Set screen dimensions (in mm)')
	parser.add_argument(
		'file',
		help='Binary EDID dump to process/modify')
	opts = parser.parse_args(argv)

	# one of the two actions has to be requested; error() exits
	if not opts.get and opts.set is None:
		parser.error('No action specified')

	if opts.get:
		# reading only
		mode, visitor = 'rb', GetVisitor()
	else:
		# the file is rewritten in place, so open it for update
		mode, visitor = 'r+b', SetVisitor(opts.set)

	with open(opts.file, mode) as edid_file:
		outdata = process_edid(edid_file, visitor)

		if opts.set is not None:
			# overwrite the file with the modified data
			edid_file.seek(0)
			edid_file.write(outdata)

	return 0


# Script entry point: pass the whole of sys.argv to main() (the program
# name becomes its first argument) and use the value it returns
# as the process exit status.
if __name__ == '__main__':
	sys.exit(main(*sys.argv))
