#!/usr/bin/python3
#
#       prime-select
#
#       Copyright 2013 Canonical Ltd.
#       Author: Alberto Milone <alberto.milone@canonical.com>
#
#       Script to switch between NVIDIA and Intel graphics driver libraries.
#
#       Usage:
#           prime-select   nvidia|intel|on-demand|query
#           nvidia:    switches to NVIDIA's version of libGL.so
#           on-demand: load NVIDIA driver, and on-demand for others
#           intel: switches to the open-source version of libGL.so
#           query: checks which version is currently active and writes
#                  "nvidia", "intel", "on-demand" or "unknown" to the
#                  standard output
#
#       Permission is hereby granted, free of charge, to any person
#       obtaining a copy of this software and associated documentation
#       files (the "Software"), to deal in the Software without
#       restriction, including without limitation the rights to use,
#       copy, modify, merge, publish, distribute, sublicense, and/or sell
#       copies of the Software, and to permit persons to whom the
#       Software is furnished to do so, subject to the following
#       conditions:
#
#       The above copyright notice and this permission notice shall be
#       included in all copies or substantial portions of the Software.
#
#       THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
#       EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
#       OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
#       NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
#       HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
#       WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
#       FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
#       OTHER DEALINGS IN THE SOFTWARE.

import glob
import os
import sys
import re
import subprocess
import shutil
import itertools
import time

from copy import deepcopy
from subprocess import Popen, PIPE, CalledProcessError


class Switcher(object):

    def __init__(self):
        """Record the fixed filesystem locations used by the switcher."""
        # Boot loader configuration and the command line key inside it
        self._grub_path = '/etc/default/grub'
        self._grub_cmdline_start = 'GRUB_CMDLINE_LINUX_DEFAULT='

        # File in which the selected profile is stored
        self._power_profile_path = '/etc/prime-discrete'

        # modprobe configuration files handled by this script
        self._blacklist_file = '/lib/modprobe.d/blacklist-nvidia.conf'
        self._old_blacklist_file = '/etc/modprobe.d/blacklist-nvidia.conf'
        self._nvidia_runtimepm_file = '/lib/modprobe.d/nvidia-runtimepm.conf'
        self._nvidia_kms_file = '/lib/modprobe.d/nvidia-kms.conf'

        # GDM configuration file (not used by the other methods visible here)
        self._gdm_conf_file = '/etc/gdm3/custom.conf'

    def _get_profile(self):
        """Return the name of the currently selected profile.

        The content of the power profile file is mapped as follows:
        'on' -> 'nvidia', 'on-demand' -> 'on-demand', anything else ->
        'intel'.

        Returns:
            'nvidia', 'on-demand' or 'intel', or 'unknown' when the
            profile file cannot be opened.
        """
        try:
            settings = open(self._power_profile_path, 'r')
        except OSError:
            return 'unknown'

        # Make sure the file is closed again once it has been read
        with settings:
            config = settings.read().strip()

        if config == 'on':
            return 'nvidia'
        if config == 'on-demand':
            return 'on-demand'
        return 'intel'

    def print_profile(self):
        """Print the current profile to standard output.

        Returns:
            True if a profile was printed, False if the profile is
            'unknown' (in which case nothing is printed).
        """
        current = self._get_profile()
        if current != 'unknown':
            print('%s' % current)
            return True

        return False

    def _write_profile(self, profile):
        """Store *profile* in the power profile file.

        The profile is translated to the value kept in the file:
        'intel' -> 'off', 'on-demand' -> 'on-demand', 'nvidia' -> 'on'.

        Args:
            profile: one of 'intel', 'on-demand' or 'nvidia'.

        Returns:
            True once the file has been written, False (without touching
            the file) if *profile* is not a known profile.

        Raises:
            OSError: if the profile file cannot be written.
        """
        if profile == 'intel':
            nvidia_power = 'off'
        elif profile == 'on-demand':
            nvidia_power = 'on-demand'
        elif profile == 'nvidia':
            nvidia_power = 'on'
        else:
            return False

        # The context manager closes the file even if the write fails
        with open(self._power_profile_path, 'w') as settings:
            settings.write('%s\n' % nvidia_power)

        return True

    def _is_laptop(self, path='/sys/devices/virtual/dmi/id/chassis_type'):
        """Tell whether the DMI chassis type describes a portable machine.

        Args:
            path: file containing the numeric chassis type.

        Returns:
            1 if the chassis type is Portable (8), Laptop (9),
            Notebook (10) or Convertible (31), 0 for any other chassis
            type, and -1 if the type cannot be determined (file missing,
            unreadable, empty or not a number).
        """
        # Chassis types:
        #
        # 01 Other
        # 02 Unknown
        # 03 Desktop
        # 04 Low Profile Desktop
        # 05 Pizza Box
        # 06 Mini Tower
        # 07 Tower
        # 08 Portable
        # 09 Laptop
        # 10 Notebook
        # 11 Hand Held
        # 12 Docking Station
        # 13 All In One
        # 14 Sub Notebook
        # 15 Space-saving
        # 16 Lunch Box
        # 17 Main Server Chassis
        # 18 Expansion Chassis
        # 19 Sub Chassis
        # 20 Bus Expansion Chassis
        # 21 Peripheral Chassis
        # 22 RAID Chassis
        # 23 Rack Mount Chassis
        # 24 Sealed-case PC
        # 25 Multi-system
        # 26 CompactPCI
        # 27 AdvancedTCA
        # 28 Blade
        # 29 Blade Enclosing
        # 30 Tablet
        # 31 Convertible
        # 32 Detachable
        # 33 IoT Gateway
        # 34 Embedded PC
        # 35 Mini PC
        # 36 Stick PC
        laptop_chassis_types = (8, 9, 10, 31)

        try:
            with open(path, 'r') as f:
                chassis_type = int(f.read().strip())
        except (OSError, ValueError):
            # Always return an integer: callers compare the result with 0,
            # which would fail on None.
            return -1

        if chassis_type in laptop_chassis_types:
            return 1
        return 0


    def _get_bootvga_card(self, drm_dir='/sys/class/drm'):
        """Return the DRM card directory flagged as the boot VGA device.

        Every ``card[0-9]`` entry below *drm_dir* is inspected in sorted
        order; the first one whose ``device/boot_vga`` file contains a
        number greater than zero is returned.

        Args:
            drm_dir: directory holding the ``cardN`` entries.

        Returns:
            The path of the matching card directory, or '' if there is
            none.
        """
        pattern = os.path.join(glob.escape(drm_dir), 'card[0-9]')

        # glob returns entries in arbitrary order; sort for a stable result
        for card in sorted(glob.glob(pattern)):
            bootvga_path = os.path.join(card, 'device/boot_vga')
            try:
                with open(bootvga_path, 'r') as f:
                    flag = f.read().strip()
                if flag and int(flag) > 0:
                    return card
            except (OSError, ValueError):
                # Missing, unreadable or malformed boot_vga: try next card
                continue

        return ''

    def _get_card_vendor(self, card):
        """Return the content of ``device/vendor`` for a DRM card.

        Args:
            card: path of the card directory.

        Returns:
            The stripped content of the vendor file, or '' if the file
            cannot be read.
        """
        vendor_path = os.path.join(card, 'device/vendor')
        try:
            with open(vendor_path, 'r') as f:
                return f.read().strip()
        except OSError:
            # The original fallback "return ''" could never be reached;
            # this is where it belongs.
            return ''

    def _has_integrated_gpu(
            self, path='/var/lib/ubuntu-drivers-common/last_gfx_boot'):
        """Tell whether an Intel or AMD GPU can be found.

        If *path* exists, its content is searched for the PCI vendor IDs
        8086 (Intel) or 1002 (AMD).  Otherwise the vendor of the boot VGA
        card is checked for the same IDs.

        Args:
            path: file searched for the vendor IDs.

        Returns:
            True if one of the vendor IDs was found, False otherwise.
        """
        integrated_vendors = ('8086', '1002')

        if os.path.isfile(path):
            with open(path, 'r') as f:
                ids = f.read()
        else:
            card = self._get_bootvga_card()
            # The original tested an undefined variable here instead of
            # the vendor string, which raised NameError.
            ids = self._get_card_vendor(card) if card else ''

        return any(vendor in ids for vendor in integrated_vendors)

    def _supports_runtimepm(self):
        """Return True if /run/nvidia_runtimepm_supported is a file."""
        marker = '/run/nvidia_runtimepm_supported'
        return os.path.isfile(marker)

    def _is_runtimepm_enabled(self):
        """Return True if /run/nvidia_runtimepm_enabled is a file."""
        marker = '/run/nvidia_runtimepm_enabled'
        return os.path.isfile(marker)

    def _enable_runtimepm(self):
        """Write the modprobe file that sets NVreg_DynamicPowerManagement.

        The file gets a single ``options nvidia`` line setting
        ``NVreg_DynamicPowerManagement=0x02``; its path is announced on
        standard output first.

        Raises:
            OSError: if the file cannot be written.
        """
        print('Writing %s' % self._nvidia_runtimepm_file)
        # The context manager closes the file even if the write fails
        with open(self._nvidia_runtimepm_file, 'w') as pm_fd:
            pm_fd.write('options nvidia "NVreg_DynamicPowerManagement=0x02"\n')

    def _disable_runtimepm(self):
        """Delete the runtime power management modprobe file.

        A file that does not exist is not an error.  Any other failure to
        delete it is reported on standard error instead of being silently
        ignored; no exception is raised for it.
        """
        print('Deleting %s' % self._nvidia_runtimepm_file)
        try:
            os.unlink(self._nvidia_runtimepm_file)
        except FileNotFoundError:
            # Nothing to delete
            pass
        except OSError as err:
            sys.stderr.write('Warning: cannot delete %s: %s\n'
                             % (self._nvidia_runtimepm_file, err))

    def enable_profile(self, profile):
        """Switch the system to *profile*.

        - 'nvidia': enable the NVIDIA driver.
        - 'on-demand': enable runtime power management (laptops that
          support it only), enable KMS and keep the NVIDIA modules
          loadable.
        - 'intel': disable the NVIDIA driver.

        The grub configuration is backed up first and the selected
        profile is written to the profile file at the end.

        Args:
            profile: 'nvidia', 'on-demand' or 'intel'.

        Returns:
            True on success, False if *profile* is unknown or no
            integrated GPU was detected (nothing is changed then).
        """
        # Reject unknown names up front: they used to fall through to the
        # 'intel' branch and blacklist the NVIDIA driver.
        if profile not in ('nvidia', 'on-demand', 'intel'):
            sys.stderr.write('Error: unknown profile "%s"\n' % (profile))
            return False

        if not self._has_integrated_gpu():
            sys.stderr.write('Error: no integrated GPU detected.\n')
            return False

        sys.stdout.write('Info: selecting the %s profile\n' % (profile))

        self._backup_grub_config()

        if profile == 'nvidia':
            # Always allow enabling nvidia
            # (No need to check if nvidia is available)
            self._enable_nvidia()
        elif profile == 'on-demand':
            # Enable RTD3 only on laptops.  The chassis check may not
            # yield a number, so guard the comparison against None.
            laptop = self._is_laptop()
            if laptop is not None and laptop > 0:
                if self._supports_runtimepm():
                    self._enable_runtimepm()
            self._enable_kms()
            self._disable_nvidia(keep_nvidia_modules=True)
        else:
            # NOTE: a check that the installed packages support PRIME
            # used to be performed here; it is currently disabled.
            self._disable_nvidia()

        # Write the settings to the config file
        self._write_profile(profile)

        return True

    def _disable_nvidia(self, keep_nvidia_modules=False):
        """Disable the NVIDIA driver and update the initramfs.

        The old blacklist file is always removed.  Then either the
        current blacklist file is removed too (*keep_nvidia_modules* is
        true) or it is rewritten to blacklist the NVIDIA modules.

        Args:
            keep_nvidia_modules: when true, do not blacklist the NVIDIA
                kernel modules.
        """
        stale_files = [self._old_blacklist_file]
        if keep_nvidia_modules:
            stale_files.append(self._blacklist_file)

        for path in stale_files:
            try:
                os.unlink(path)
            except FileNotFoundError:
                # Already gone, nothing to do
                pass
            except OSError as err:
                sys.stderr.write('Warning: cannot delete %s: %s\n'
                                 % (path, err))

        if not keep_nvidia_modules:
            self._blacklist_nvidia()

        # Update the initramfs
        self._update_initramfs()

    def _enable_nvidia(self):
        """Enable the NVIDIA driver and update the initramfs.

        Removes both blacklist files and the runtime power management
        file, creates the KMS configuration file if it does not exist
        yet, and regenerates the initramfs.
        """
        def remove(path):
            # Delete path; a missing file is fine, other errors are reported
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass
            except OSError as err:
                sys.stderr.write('Warning: cannot delete %s: %s\n'
                                 % (path, err))

        remove(self._old_blacklist_file)

        self._disable_runtimepm()

        remove(self._blacklist_file)

        # Create the KMS configuration file only if there is none yet,
        # so that a value changed by the user is preserved.
        if not os.path.isfile(self._nvidia_kms_file):
            self._enable_kms()

        # Update the initramfs
        self._update_initramfs()

    def _blacklist_nvidia(self):
        """Write the modprobe file that blacklists the NVIDIA modules.

        The file blacklists nvidia, nvidia-drm and nvidia-modeset and
        aliases each of them to "off".  An existing file is overwritten.

        Raises:
            OSError: if the file cannot be written.
        """
        blacklist_text = '\n'.join((
            '# Do not modify',
            '# This file was generated by nvidia-prime',
            'blacklist nvidia',
            'blacklist nvidia-drm',
            'blacklist nvidia-modeset',
            'alias nvidia off',
            'alias nvidia-drm off',
            'alias nvidia-modeset off',
        ))
        # The context manager closes the file even if the write fails
        with open(self._blacklist_file, 'w') as blacklist_fd:
            blacklist_fd.write(blacklist_text)

    def _write_kms_settings(self, value):
        """Write the nvidia-drm modesetting option to the KMS file.

        Args:
            value: integer written as the ``modeset`` option
                (1 enables modesetting, per the comment in the file).

        Raises:
            OSError: if the file cannot be written.
        """
        # Users can change the value in the generated file by hand.
        template = ('# This file was generated by nvidia-prime\n'
                    '# Set value to 1 to enable modesetting\n'
                    'options nvidia-drm modeset=%d')
        kms_text = template % (value)
        # The context manager closes the file even if the write fails
        with open(self._nvidia_kms_file, 'w') as kms_fd:
            kms_fd.write(kms_text)

    def _enable_kms(self):
        """Write the KMS configuration file with modeset set to 1."""
        modeset = 1
        self._write_kms_settings(modeset)

    def _disable_kms(self):
        """Write the KMS configuration file with modeset set to 0."""
        modeset = 0
        self._write_kms_settings(modeset)

    def _add_boot_params(self, pattern, path, params):
        """Add or update boot parameters in a configuration file.

        Every line of *path* starting with *pattern* is treated as
        ``pattern"arg arg ..."``.  For each ``key: value`` in *params*,
        an existing argument ``key`` or ``key=...`` is replaced by
        ``key=value``; if there is none, ``key=value`` is appended.  All
        other lines are written back unchanged.

        Arguments are matched by their exact key, so adding ``quiet``
        does not touch ``noquiet``.

        Args:
            pattern: line prefix, e.g. 'GRUB_CMDLINE_LINUX_DEFAULT='.
            path: file to modify in place.
            params: dict mapping parameter names to values.

        Raises:
            OSError: if the file cannot be read or written.
        """
        with open(path, 'r+') as f:
            lines = f.read().split('\n')

            for index, line in enumerate(lines):
                if not line.startswith(pattern):
                    continue

                # split() without arguments also copes with an empty
                # command line and with repeated spaces
                boot_args = line[len(pattern):].replace('"', '').split()

                for key, value in params.items():
                    target_param = '%s=%s' % (key, value)
                    key_prefix = '%s=' % (key)
                    arg_found = False
                    for i, arg in enumerate(boot_args):
                        if arg == key or arg.startswith(key_prefix):
                            boot_args[i] = target_param
                            arg_found = True
                    if not arg_found:
                        boot_args.append(target_param)

                lines[index] = '%s"%s"' % (pattern, ' '.join(boot_args))

            f.seek(0)
            f.write('\n'.join(lines))
            f.truncate()

    def _remove_boot_params(self, pattern, path, params):
        """Remove boot parameters from a configuration file.

        Every line of *path* starting with *pattern* is treated as
        ``pattern"arg arg ..."``.  An argument is dropped if it equals
        one of *params* or has the form ``param=...``.  All other lines
        are written back unchanged.

        Arguments are matched exactly, so removing ``quiet`` does not
        remove ``noquiet``.

        Args:
            pattern: line prefix, e.g. 'GRUB_CMDLINE_LINUX_DEFAULT='.
            path: file to modify in place.
            params: iterable of parameter names (or complete
                ``name=value`` arguments) to remove.

        Raises:
            OSError: if the file cannot be read or written.
        """
        # Materialise once: params is consulted for every argument
        keys = list(params)

        def is_unwanted(arg):
            # True if arg is one of the parameters to remove
            return any(arg == key or arg.startswith('%s=' % (key))
                       for key in keys)

        with open(path, 'r+') as f:
            lines = f.read().split('\n')

            for index, line in enumerate(lines):
                if not line.startswith(pattern):
                    continue

                boot_args = line[len(pattern):].replace('"', '').split()
                kept_args = [arg for arg in boot_args
                             if not is_unwanted(arg)]
                lines[index] = '%s"%s"' % (pattern, ' '.join(kept_args))

            f.seek(0)
            f.write('\n'.join(lines))
            f.truncate()


    def _find_connected_connectors(self, card, drm_dir='/sys/class/drm'):
        """Return the connectors of *card* whose status is 'connected'.

        Every ``<card>-*`` entry below *drm_dir* is inspected in sorted
        order.  Connectors whose ``status`` file cannot be read are
        skipped instead of aborting the whole scan.

        Args:
            card: card name, e.g. 'card1'.
            drm_dir: directory holding the connector entries.

        Returns:
            List of connector paths, e.g.
            ['/sys/class/drm/card1-VGA-1'].
        """
        pattern = '%s/%s-*' % (glob.escape(drm_dir), card)
        connected_connectors = []

        # glob returns entries in arbitrary order; sort for a stable result
        for connector in sorted(glob.glob(pattern)):
            status_path = '%s/status' % connector
            try:
                with open(status_path, 'r') as f:
                    status = f.read().strip()
            except OSError:
                continue
            if status == 'connected':
                connected_connectors.append(connector)

        return connected_connectors

    def _get_boot_params_from_phantom_vga_connectors(self):
        """Build ``video=<connector>:d`` parameters for card1's VGA outputs.

        Looks at the connected connectors of 'card1' and, for each one
        with 'vga' in its path (case-insensitive), strips the
        '/sys/class/drm/card1-' prefix and all dashes from the path to
        form the connector name.

        Returns:
            List of strings such as 'video=VGA1:d'.
        """
        sysfs_prefix = '/sys/class/drm/card1-'
        return [
            'video=%s:d' % entry.replace(sysfs_prefix, '').replace('-', '')
            for entry in self._find_connected_connectors('card1')
            if 'vga' in entry.lower()
        ]

    def _update_initramfs(self, command=('update-initramfs', '-u')):
        """Regenerate the initramfs while showing a spinner.

        The standard output of the command is discarded.  It used to be
        sent to a pipe that was never read while waiting, which could
        block the command forever once the pipe filled up.

        Args:
            command: command line to run, as a sequence of strings.

        Returns:
            True if the command exited with status 0, False if it failed
            or if standard output went away while waiting (the command
            is still waited for in that case).

        Raises:
            OSError: if the command cannot be started.
        """
        # Create spinner to give feed back on the
        # operation
        spinner = itertools.cycle(['-', '/', '|', '\\'])
        proc = subprocess.Popen(list(command), stdout=subprocess.DEVNULL)
        print('Updating the initramfs. Please wait for the operation to complete:')

        # Check if process is still running
        while proc.poll() is None:
            try:
                # Print the spinner
                sys.stdout.write(next(spinner))
                sys.stdout.flush()
                sys.stdout.write('\b')
            except BrokenPipeError:
                # Nobody is watching any more, but let the update finish
                proc.wait()
                return False

            # Sleep for at most 0.2s, waking up early when the command ends
            try:
                proc.wait(timeout=0.2)
            except subprocess.TimeoutExpired:
                pass

        if proc.returncode != 0:
            sys.stderr.write('Error: %s exited with status %d\n'
                             % (command[0], proc.returncode))
            return False

        print('Done')
        return True

    def _update_grub(self):
        """Run ``update-grub`` and wait for it; its exit status is ignored."""
        command = ['update-grub']
        subprocess.call(command)

    def _backup_grub_config(self):
        """Copy the grub configuration to ``<grub path>.prime-backup``.

        An existing backup is never overwritten.  If there is no grub
        configuration file, a warning is written to standard error and
        nothing is copied (this used to raise an exception).

        Raises:
            OSError: if the copy itself fails.
        """
        destination = '%s.prime-backup' % self._grub_path
        if os.path.isfile(destination):
            return

        if not os.path.isfile(self._grub_path):
            sys.stderr.write('Warning: %s not found, no backup created\n'
                             % self._grub_path)
            return

        shutil.copyfile(self._grub_path, destination)


def check_root():
    """Exit with status 1 unless the effective user is root.

    An error message is written to standard error before exiting.

    Raises:
        SystemExit: if the effective user ID is not 0.
    """
    if os.geteuid() != 0:
        sys.stderr.write("This operation requires root privileges\n")
        # sys.exit rather than the exit() helper, which is only provided
        # by the site module for interactive use
        sys.exit(1)

def handle_query_error():
    """Report that no profile could be found and exit with status 1.

    Raises:
        SystemExit: always.
    """
    sys.stderr.write("Error: no profile can be found\n")
    # sys.exit rather than the exit() helper, which is only provided by
    # the site module for interactive use
    sys.exit(1)

def usage():
    """Write a one-line usage summary to standard error."""
    program = sys.argv[0]
    message = "Usage: " + program + " nvidia|intel|on-demand|query\n"
    sys.stderr.write(message)

if __name__ == '__main__':
    # Exactly one argument is expected: the profile name or 'query'
    if len(sys.argv) != 2:
        usage()
        sys.exit(1)

    arg = sys.argv[1]
    switcher = Switcher()

    if arg in ('intel', 'nvidia', 'on-demand'):
        check_root()
        # Report a failed switch through the exit status instead of
        # always exiting with 0
        if not switcher.enable_profile(arg):
            sys.exit(1)
    elif arg == 'query':
        if not switcher.print_profile():
            handle_query_error()
    else:
        usage()
        sys.exit(1)

    sys.exit(0)
