#! /usr/bin/env python

"""\
%(prog)s [-o outfile] <yodafile1>[:<scale1>] <yodafile2>[:<scale2>] ...
  e.g. %(prog)s run1.yoda run2.yoda run3.yoda  (unweighted merging of three runs)
    or %(prog)s run1.yoda:2.0 run2.yoda:3.142  (weighted merging of two runs)

Merge analysis objects from multiple YODA files, combining the statistics of
objects whose names are found in multiple files. May be used either to merge
disjoint collections of data objects, or to combine multiple statistically
independent runs of the same data objects into one high-statistics run. Optional
scaling parameters may be given to rescale the weights of the objects on a
per-file basis before merging.

By default the output is written to stdout since we can't guess what would be
a good automatic filename choice! Use the -o option to provide an output filename.


IMPORTANT!
  This script is not meant to handle all run merging situations or data objects:
  there are limitations to what can be inferred from data objects alone. If you
  need to do something more complex than the common cases handled by this script,
  please write your own script / program to load and process the data objects.


SCATTERS (E.G. HISTOGRAM RATIOS) CAN'T BE MERGED

  Note that 'scatter' data objects, as opposed to histograms, cannot be merged
  by this tool since they do not preserve sufficient statistical
  information. The canonical example of this is a ratio plot: there are
  infinitely many combinations of numerator and denominator which could give the
  same ratio, and the result does not indicate anything about which of those
  infinite inputs is right (or the effects of correlations in the division).

  If you need to merge Scatter2D objects, you can write your own Python script
  or C++ program using the YODA interface, and apply whatever case-specific
  treatment is appropriate. By default this script falls back to a simple
  unweighted average of the point values across the input files, with the errors
  from each source combined in quadrature: no statistically rigorous merging is
  done.
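
  As a sketch of one such case-specific treatment (an illustration only, not
  the YODA API): if the scatter copies come from statistically independent,
  equally weighted runs, an unweighted average with errors combined in
  quadrature may be adequate. The function name and the tuple layout below are
  hypothetical.

```python
import math

def average_points(copies):
    # copies: one list per input file of (value, err_minus, err_plus) tuples,
    # a hypothetical stand-in for the points of a Scatter object.
    n = len(copies)
    merged = []
    for pts in zip(*copies):
        val = sum(p[0] for p in pts) / n
        # Independent errors add in quadrature; dividing by n gives the
        # error on the mean of the n copies.
        em = math.sqrt(sum(p[1] ** 2 for p in pts)) / n
        ep = math.sqrt(sum(p[2] ** 2 for p in pts)) / n
        merged.append((val, em, ep))
    return merged

# Two single-point copies of the same scatter
merged = average_points([[(1.0, 0.3, 0.3)], [(3.0, 0.4, 0.4)]])
```

  This is only meaningful when every copy has the same binning and comparable
  statistics; it says nothing about correlations between the inputs.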

NORMALIZED, UNNORMALIZED, OR A MIX?

  An important detail in histogram merging is whether a statistical treatment
  for normalized or unnormalized histograms should be used: in the former case
  the normalization scaling must be undone *before* the histograms are added
  together, and then re-applied afterwards. This script examines the ScaledBy
  attribute of each histogram to determine if it has been normalized. We make the
  assumption that if ScaledBy exists (i.e. h.scaleW has been called) then the
  histogram is normalized and we normalize the resulting merged histogram to the
  weighted average of input norms; if there is no ScaledBy, we assume that the
  histogram is not normalized.

  This is not an infallible approach, but we believe it is more robust than using
  heuristics to decide whether norms are sufficiently close to be considered equal.
  In complicated situations you will again be better off writing your own
  script or program to do the merging: the merging machinery of this script is
  available directly in the yoda Python module.
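
  The unscale, add, re-scale sequence described above can be sketched in plain
  Python as follows. This is a minimal illustration under stated assumptions,
  not the yoda API: merge_bins is a hypothetical helper, each histogram is
  reduced to a list of per-bin weight sums, and scaledby stands in for the
  ScaledBy annotation (None if the histogram was never scaled).

```python
def merge_bins(inputs):
    # inputs: list of (bins, scaledby) pairs, one per input file.
    merged, sum_sfs = None, 0.0
    for bins, scaledby in inputs:
        hscale = 1.0 / scaledby if scaledby else None
        if hscale is not None:
            # Undo the normalization so that raw weights are added
            bins = [w * hscale for w in bins]
            sum_sfs += hscale
        if merged is None:
            merged = list(bins)
        else:
            merged = [a + b for a, b in zip(merged, bins)]
    if sum_sfs:
        # Re-apply a norm equal to the weighted average of the input norms
        merged = [w / sum_sfs for w in merged]
    return merged

# Two unit-normalized runs; the second has four times the raw weight
run_a = ([0.25, 0.75], 0.25)    # raw sums [1, 3] scaled by 1/4
run_b = ([0.75, 0.25], 0.0625)  # raw sums [12, 4] scaled by 1/16
merged = merge_bins([run_a, run_b])
```

  Here the merged shape is approximately [0.65, 0.35] rather than the naive
  [0.5, 0.5], because the higher-statistics run dominates the sum.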

See the source of this script (e.g. use 'less `which %(prog)s`') for more discussion.
"""

# MORE NOTES
#
# If all the input histograms with a particular path are found to have the same
# normalization, and they have ScaledBy attributes indicating that a histogram
# weight scaling has been applied in producing the input histograms, each
# histogram in that group will first be unscaled by its appropriate factor, then
# merged, and then re-normalized to the target value. Otherwise the weights from
# each histogram copy will be directly added together with no attempt to guess an
# appropriate normalization. The normalization guesses (and they are guesses --
# see below) are made *before* application of the per-file scaling arguments.
#
# IMPORTANT: note from the above that this script can't work out what to do
# re. scaling and normalization of output histograms from the input data files
# alone. It may be possible (although unlikely) that input histograms have the
# same normalization but are meant to be added directly. It may also be the case
# (and much more likely) that histograms which should be normalized to a common
# value will not trigger the appropriate treatment due to e.g. statistical
# fluctuations in each run's calculation of a cross-section used in the
# normalization. And anything more complex than a global scaling (e.g. calculation
# of a ratio or asymmetry) cannot be handled at all with a post-hoc scaling
# treatment. The --assume-normalized command line option will force all histograms
# to be treated as if they are normalized in the input, which can be useful if
# you know that all the output histograms are indeed of this nature. If they are
# not, it will go wrong: you have been warned!
#
# Please use this script as a template if you need to do something more specific.
#
# NOTE: there are many possible desired behaviours when merging runs, depending on
# the factors above as well as whether the files being merged are of homogeneous
# type, heterogeneous type, or a combination of both. It is tempting, therefore,
# to add a large number of optional command-line parameters to this script, to
# handle these cases. Experience from Rivet 1.x suggests that this is a bad idea:
# if a problem is of programmatic complexity then a command-line interface which
# attempts to solve it in general is doomed to both failure and unusability. Hence
# we will NOT add extra arguments for applying different merging weights or
# strategies based on analysis object path regexes, auto-identifying 'types' of
# run, etc., etc.: if you need to merge data files in such complex ways, please
# use this script as a template around which to write logic that satisfies your
# particular requirements.


import yoda, argparse, sys, math

parser = argparse.ArgumentParser(usage=__doc__)
parser.add_argument("INFILES", nargs="+", help="datafile1 datafile2 [...]")
parser.add_argument("-o", "--output", default="-", dest="OUTPUT_FILE", metavar="PATH",
                    help="write output to specified path")
parser.add_argument("-m", "--match", dest="MATCH", metavar="PATT", default=None,
                    help="only write out histograms whose path matches this regex")
parser.add_argument("-M", "--unmatch", dest="UNMATCH", metavar="PATT", default=None,
                    help="exclude histograms whose path matches this regex")
parser.add_argument('-v', '--verbose', action="store_const", const=2, default=1, dest='VERBOSITY',
                    help="print extra merging details")
parser.add_argument('-q', '--quiet', action="store_const", const=0, default=1, dest='VERBOSITY',
                    help="only print fatal errors")
parser.add_argument("--add", "--stack", action="store_true", default=False, dest="STACK",
                    help="DEPRECATED, USE 'yodastack' INSTEAD.")
args = parser.parse_args()

if args.STACK:
  cmd = "WARNING: deprecated option used! For simple stacking, please run the following command instead:\n  yodastack "
  cmd += " ".join([ arg for arg in sys.argv[1:] if '--add' not in arg and '--stack' not in arg ])
  sys.exit(cmd)

#def eq(a, b, tol=5e-2):
#    if a == 0 and b == 0:
#        return True
#    return 2*abs((a-b)/(abs(a)+abs(b))) < tol


## Loop over all input files
ntotal = len(args.INFILES)
aos_in, aos_tmp, sfs, sources = None, {}, {}, {}
for n, filename in enumerate(args.INFILES):

    ## Tell the user which file is being merged
    if args.VERBOSITY > 0:
        msg = "Merging data file {:s} [{:d}/{:d}]".format(filename, n+1, ntotal)
        sys.stdout.write(msg + "\n")

    ## Release old AOs before loading the new ones to reduce the peak memory usage
    del aos_in

    ## Read the next batch of incoming objects into a dict mapping each path to its analysis object
    aos_in = yoda.read(filename, True, args.MATCH, args.UNMATCH)

    ## Record the inverse of the internal histo scaling (1/ScaledBy) for each AO, or None if unscaled
    for aopath, ao in aos_in.items():
        scaledby = float(ao.annotation("ScaledBy", 0.0))
        hscale = 1.0/scaledby if scaledby else None
        sfs.setdefault(aopath, []).append(hscale)

    ## Process each AO in the set from this file
    for aopath, ao in aos_in.items():

        ## Counter, Histo and Profile (i.e. Fillable) merging
        # TODO: Add a Fillable interface and use that for the type matching
        aotype = type(ao)
        if aotype in (yoda.Counter, yoda.Histo1D, yoda.Histo2D, yoda.Profile1D, yoda.Profile2D):

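            ## Undo the normalization scaling if appropriate. Use this file's entry ([-1]),
            ## not [n]: a path missing from an earlier file has fewer than n+1 entries.
            if sfs[aopath][-1] is not None:
                ao.scaleW( sfs[aopath][-1] )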

            ## Initialise or add to the aggregating ((de-)scaled) output histos
            if aopath not in aos_tmp:
                aos_tmp[aopath] = ao.clone()
            else:
                aotype_ref = type(aos_tmp[aopath])
                if aotype == aotype_ref:
                    aos_tmp[aopath] += ao
                else:
                    ## If type mismatches, skip this one and mark its weight contribution as null
                    sys.stderr.write("Unexpected type for analysis object '{}' from {}: {} vs expected {}\n".format(aopath, filename, aotype, aotype_ref))
                    sfs[aopath][-1] = None

        ## Scatter "merging"
        elif aotype in (yoda.Scatter1D,yoda.Scatter2D,yoda.Scatter3D):

            if args.VERBOSITY > 1:
                sys.stderr.write("WARNING: scatter object '{}' cannot be merged: will perform simple average\n".format(aopath))

            ## Increment the counter used for the average
            sfs.setdefault(aopath, []).append(1.0)

            ## Initialise or add to the aggregating output scatters
            if aopath not in aos_tmp:
                aos_tmp[aopath] = ao.clone()
                sources[aopath] = ao.variations()
            else:
                ## Retrieve dimensionality of the Scatter*D object
                dim = aos_tmp[aopath].dim()
                npoints = len(aos_tmp[aopath].points())
                for i in range(npoints):
                    val_i = aos_tmp[aopath].point(i).val(dim) + ao.point(i).val(dim)
                    aos_tmp[aopath].point(i).setVal(dim, val_i)
                    # update the values of the multiple error sources
                    for var in ao.variations():
                        ep_i = 0.
                        em_i = 0.
                        if var in sources[aopath]:
                            ep_i = aos_tmp[aopath].point(i).errs(dim,var)[0]
                            em_i = aos_tmp[aopath].point(i).errs(dim,var)[1]
                        else:
                            sources[aopath].append(var)
                        ep_i = math.sqrt(ep_i**2 + ao.point(i).errs(dim,var)[0]**2)
                        em_i = math.sqrt(em_i**2 + ao.point(i).errs(dim,var)[1]**2)
                        aos_tmp[aopath].point(i).setErrs(dim, (ep_i, em_i), var)
        ## Other data types (just warn: such objects are not passed on to the output)
        else:
            if args.VERBOSITY > 0:
                sys.stderr.write("WARNING: Analysis object %s of type %s cannot be merged\n" % (aopath, str(aotype)))


## Rescale and output, after reduction
aos_out = {}
for aopath, ao in aos_tmp.items():

    ## Re-apply normalisation factors if needed
    #print(aopath, type(ao))
    sum_sfs = sum(f for f in sfs[aopath] if f)
    #print(sum_sfs)
    if sum_sfs:
        if type(ao) in (yoda.Counter, yoda.Histo1D, yoda.Histo2D, yoda.Profile1D, yoda.Profile2D):
            ao.scaleW(1.0 / sum_sfs)
        else:
            dim = ao.dim()
            ao.scale(dim, 1.0 / sum_sfs)

    ## Put in output container
    aos_out[aopath] = ao


## Write output
if args.VERBOSITY > 0:
    sys.stdout.write("Writing merged data to {}\n".format(args.OUTPUT_FILE))
yoda.writeYODA(aos_out, args.OUTPUT_FILE)
