#!/usr/bin/env python
#
# Copyright (c) the JPEG XL Project Authors. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.

r"""General-purpose bisector

Prints a space-separated list of values to stdout:
1_if_success_0_otherwise left_x left_f(x) right_x right_f(x)

Usage examples:

# Finding the square root of 200 via bisection:
bisector --var=BB --range=0.0,100.0 --target=200 --maxiter=100 \
         --atol_val=1e-12 --rtol_val=0 --cmd='echo "$BB * $BB" | bc'
# => 1 14.142135623730923 199.99999999999923 14.142135623731633 200.0000000000193

# Finding an integer approximation to sqrt(200) via bisection:
bisector --var=BB --range=0,100 --target=200 --maxiter=100 \
         --atol_arg=1 --cmd='echo "$BB * $BB" | bc'
# => 1 14 196.0 15 225.0

# Finding a change-id that broke something via bisection:
bisector --var=BB --range=0,1000000 --target=0.5 --maxiter=100 \
         --atol_arg=1 \
         --cmd='test $BB -gt 123456 && echo 1 || echo 0' --verbosity=3
# => 1 123456 0.0 123457 1.0

# Finding settings that compress /usr/share/dict/words to a given target size:
bisector --var=BB --range=1,9 --target=250000 --atol_arg=1 \
  --cmd='gzip -$BB </usr/share/dict/words >/tmp/w_$BB.gz; wc -c /tmp/w_$BB.gz' \
  --final='mv /tmp/w_$BB.gz /tmp/words.gz; rm /tmp/w_*.gz' \
  --verbosity=1
# => 1 3 263170.0 4 240043.0

# JXL-encoding with bisection-for-size (tolerance 0.5%):
bisector --var=BB --range=0.1,3.0 --target=3500 --rtol_val=0.005 \
  --cmd='(build/tools/cjxl --distance=$BB /tmp/baseball.png /tmp/baseball_$BB.jxl && wc -c /tmp/baseball_$BB.jxl)' \
  --final='mv /tmp/baseball_$BB.jxl /tmp/baseball.jxl; rm -f /tmp/baseball_*.jxl' \
  --verbosity=1
# => 1 1.1875 3573.0 1.278125 3481.0

# JXL-encoding with bisection-for-bits-per-pixel (tolerance 0.5%), using helper:
bisector --var=BB --range=0.1,3.0 --target=1.2 --rtol_val=0.005 \
  --cmd='(build/tools/cjxl --distance=$BB /tmp/baseball.png /tmp/baseball_$BB.jxl && get_bpp /tmp/baseball_$BB.jxl)' \
  --final='mv /tmp/baseball_$BB.jxl /tmp/baseball.jxl; rm -f /tmp/baseball_*.jxl' \
  --verbosity=1
# => ...
"""

import argparse
import os
import re
import subprocess
import sys


def _expandvars(vardef, env,
                max_recursion=100,
                max_length=10**6,
                verbosity=0):
  """os.path.expandvars() variant using parameter env rather than os.environ."""
  current_expanded = vardef
  for num_recursions in range(max_recursion):
    if verbosity >= 3:
      print(f'_expandvars(): num_recursions={num_recursions}, '
            f'len={len(current_expanded)}' +
            (', current: ' + current_expanded if verbosity >= 4 else ''))
    if len > max_length:
        break
    current_expanded, num_replacements = re.subn(
        r'$\{(\w+)\}|$(\w+)',
        lambda m: env.get(m[1] if m[1] is not None else m[2], ''),
        current_expanded)
    if num_replacements == 0:
        break
  return current_expanded


def _strtod(string):
  """Extracts leftmost float from string (like strtod(3))."""
  match = re.match(r'[+-]?\d*[.]?\d*(?:[eE][+-]?\d+)?', string)
  return float(match[0]) if match[0] else None

  
def run_shell_command(shell_command,
                      bisect_var, bisect_val,
                      extra_env_defs,
                      verbosity=0):
  """Runs a shell command with env modifications, fetching return value.

  The command sees a copy of the current environment in which
  `bisect_var` is set to `str(bisect_val)`, plus the variables from
  `extra_env_defs`, which are added in order so that later definitions
  may refer to earlier ones (and to `bisect_var`).

  Args:
    shell_command: The command line, executed via a subshell.
    bisect_var: Name of the environment variable carrying the argument.
    bisect_val: The argument value to expose as `bisect_var`.
    extra_env_defs: Iterable of 'NAME=VALUE' strings; VALUE may contain
      `$NAME` / `${NAME}` references.
    verbosity: At >= 2, prints the command, its return code and stdout.

  Returns:
    A pair `(success, score)`: whether the command exited with status 0,
    and the leading number parsed from its stdout (None if there is none).

  Raises:
    ValueError: If an entry of `extra_env_defs` is not of the form
      NAME=VALUE.
  """
  shell_env = dict(os.environ)
  shell_env[bisect_var] = str(bisect_val)
  for env_def in extra_env_defs:
    varname, separator, vardef = env_def.partition('=')
    if not separator:
      raise ValueError(
          f'Invalid environment definition (expected NAME=VALUE): '
          f'{env_def!r}')
    shell_env[varname] = _expandvars(vardef, shell_env,
                                     verbosity=verbosity)
  shell_ret = subprocess.run(shell_command,
                             # We explicitly want subshell semantics!
                             shell=True,
                             capture_output=True,
                             env=shell_env)
  stdout = shell_ret.stdout.decode('utf-8')
  score = _strtod(stdout)
  if verbosity >= 2:
    print(f'{bisect_var}={bisect_val} {shell_command} => '
          f'{shell_ret.returncode} # {stdout.strip()}')
  return (shell_ret.returncode == 0,  # Command was successful?
          score)


def _bisect(*,
            shell_command,
            final_shell_command,
            target,
            int_args,
            bisect_var, bisect_left, bisect_right,
            rtol_val, atol_val, rtol_arg, atol_arg,
            maxiter,
            extra_env_defs,
            verbosity=0
            ):
  """Performs bisection.

  Evaluates `shell_command` at both ends of the argument interval and
  then repeatedly at its midpoint, shrinking the interval so that the
  target value stays bracketed, until a tolerance is met or `maxiter`
  iterations have been performed.

  Args:
    shell_command: Command whose numerical stdout is the function value.
    final_shell_command: Optional cleanup command; if truthy it is always
      run at the end, with `bisect_var` set to the last evaluated midpoint
      (or to `bisect_left` if no midpoint was evaluated).
    target: The function value to aim for.
    int_args: If true, midpoints are rounded to integers.
    bisect_var: Name of the environment variable carrying the argument.
    bisect_left: Initial left end of the argument interval.
    bisect_right: Initial right end of the argument interval.
    rtol_val: Relative tolerance on the value (scaled by `abs(target)`).
    atol_val: Absolute tolerance on the value.
    rtol_arg: Relative tolerance on the argument interval width.
    atol_arg: Absolute tolerance on the argument interval width.
    maxiter: Maximal number of midpoint evaluations.
    extra_env_defs: Extra 'NAME=VALUE' definitions passed to the command.
    verbosity: At >= 1, prints the bracket after every iteration.

  Returns:
    A tuple `(success, left_x, left_value, right_x, right_value)` where
    `success` is 1 if a tolerance was met and 0 if `maxiter` ran out.

  Raises:
    RuntimeError: If the command fails, prints no number, or the target
      is not bracketed by the values at the initial interval ends.
  """
  def _get_val(x):
    success, val = run_shell_command(shell_command,
                                     bisect_var, x,
                                     extra_env_defs,
                                     verbosity=verbosity)
    if not success:
      raise RuntimeError(f'Bisection failed for: {bisect_var}={x}: '
                         f'success={success}, val={val}, '
                         f'cmd={shell_command}, var={bisect_var}')
    if val is None:
      # Without this check, the comparisons against `target` below would
      # fail with an unhelpful TypeError.
      raise RuntimeError(
          f'Command printed no numerical result for: {bisect_var}={x}: '
          f'cmd={shell_command}')
    return val
  #
  bisect_mid, value_mid = None, None
  try:
    value_left = _get_val(bisect_left)
    value_right = _get_val(bisect_right)
    # The target must lie between the two end values (in either order).
    if (value_left < target) != (target <= value_right):
      raise RuntimeError(
          f'Cannot bisect: target={target}, value_left={value_left}, '
          f'value_right={value_right}')
    for _ in range(maxiter):
      bisect_mid_f = 0.5 * (bisect_left + bisect_right)
      bisect_mid = round(bisect_mid_f) if int_args else bisect_mid_f
      value_mid = _get_val(bisect_mid)
      if (value_left < target) == (value_mid < target):
        # Relative to target, `value_mid` is on the same side
        # as `value_left`.
        bisect_left = bisect_mid
        value_left = value_mid
      else:
        # Otherwise, this situation must hold for value_right
        # ("tertium non datur").
        bisect_right = bisect_mid
        value_right = value_mid
      if verbosity >= 1:
        print(f'bisect target={target}, '
              f'left: {value_left} at {bisect_left}, '
              f'right: {value_right} at {bisect_right}, '
              f'mid: {value_mid} at {bisect_mid}')
      delta_val = target - value_mid
      if abs(delta_val) <= atol_val + rtol_val * abs(target):
        return 1, bisect_left, value_left, bisect_right, value_right
      delta_arg = bisect_right - bisect_left
      # Also check whether the argument is "within tolerance".
      # Here, we have to be careful if bisect_left and bisect_right
      # have different signs: Then, their absolute magnitude
      # "sets the relevant scale".
      if abs(delta_arg) <= atol_arg + (
              rtol_arg * 0.5 * (abs(bisect_left) + abs(bisect_right))):
        return 1, bisect_left, value_left, bisect_right, value_right
    return 0, bisect_left, value_left, bisect_right, value_right
  finally:
    # If cleanup is specified, always run it
    if final_shell_command:
      run_shell_command(
          final_shell_command,
          bisect_var,
          bisect_mid if bisect_mid is not None else bisect_left,
          extra_env_defs, verbosity=verbosity)

def main(args):
  """Main entry point.

  Parses the command-line arguments, runs the bisection and prints the
  space-separated result tuple to stdout. Any problem encountered while
  interpreting `--range` or while bisecting terminates the process via
  sys.exit() with a 'Problem: ...' message.

  Args:
    args: The command-line arguments, without the program name.
  """
  parser = argparse.ArgumentParser(
      description='General-purpose bisector for shell commands.')
  parser.add_argument(
      '--var',
      help='The variable to use for bisection.',
      default='BISECT')
  parser.add_argument(
      '--range',
      help=('The argument range for bisecting, as {low},{high}. '
            'If no argument has a decimal dot, assume integer parameters.'),
      default='0.0,1.0')
  # Accepted so that existing invocations keep working; the value is not
  # read anywhere below.
  parser.add_argument(
      '--max',
      help=('The maximal value for bisecting '
            '(currently ignored; use --range).'),
      type=float,
      default=0.0)
  parser.add_argument(
      '--target',
      help='The target value to aim for.',
      type=float,
      default=1.0)
  parser.add_argument(
      '--maxiter',
      help='The maximal number of iterations to perform.',
      type=int,
      default=40)
  parser.add_argument(
      '--rtol_val',
      help='Relative tolerance to accept for deviations from target value.',
      type=float,
      default=0.0)
  parser.add_argument(
      '--atol_val',
      help='Absolute tolerance to accept for deviations from target value.',
      type=float,
      default=0.0)
  parser.add_argument(
      '--rtol_arg',
      help='Relative tolerance to accept for the argument.',
      type=float,
      default=0.0)
  parser.add_argument(
      '--atol_arg',
      help=('Absolute tolerance to accept for the argument '
            '(e.g. for bisecting change-IDs).'),
      type=float,
      default=0.0)
  parser.add_argument(
      '--verbosity',
      help='The verbosity level.',
      type=int,
      default=1)
  parser.add_argument(
      '--env',
      help=('Comma-separated list of extra environment variables '
            'to incrementally add before executing the shell-command.'),
      default='')
  # Without a command there is nothing to bisect, so let argparse report
  # the omission up front.
  parser.add_argument(
      '--cmd',
      required=True,
      help=('The shell command to execute. Must print a numerical result '
            'to stdout.'))
  parser.add_argument(
      '--final',
      help='The cleanup shell command to execute.')
  #
  parsed = parser.parse_args(args)
  # Drop empty entries, so that an empty --env yields no definitions.
  extra_env_defs = tuple(filter(None, parsed.env.split(',')))
  try:
    low_high = parsed.range.split(',')
    if len(low_high) != 2:
      raise ValueError('--range must be {low},{high}')
    int_args = False
    low_val, high_val = map(float, low_high)
    low_val_int = round(low_val)
    high_val_int = round(high_val)
    # Only treat the range as integral if both ends were literally
    # written as integers (e.g. '0,100' but not '0.0,100').
    if low_high == [str(low_val_int), str(high_val_int)]:
      int_args = True
      low_val = low_val_int
      high_val = high_val_int
    ret = _bisect(
        shell_command=parsed.cmd,
        final_shell_command=parsed.final,
        target=parsed.target,
        int_args=int_args,
        bisect_var=parsed.var,
        bisect_left=low_val,
        bisect_right=high_val,
        rtol_val=parsed.rtol_val,
        atol_val=parsed.atol_val,
        rtol_arg=parsed.rtol_arg,
        atol_arg=parsed.atol_arg,
        maxiter=parsed.maxiter,
        extra_env_defs=extra_env_defs,
        verbosity=parsed.verbosity,
    )
    print(' '.join(map(str, ret)))
  except Exception as exn:
    sys.exit(f'Problem: {exn}')


# Script entry point: forward the command-line arguments (without the
# program name) to main().
if __name__ == '__main__':
  cli_args = sys.argv[1:]
  main(cli_args)
