#! /usr/bin/env python

"""\
%(prog)s [-o outfile] <yodafile1>[:<scale1>] <yodafile2>[:<scale2>] ...
  e.g. %(prog)s run1.yoda run2.yoda run3.yoda  (unweighted stacking of three runs)
    or %(prog)s run1.yoda:2.0 run2.yoda:3.142  (weighted stacking of two runs)

Stack analysis objects from multiple YODA files, combining the statistics of
objects whose names are found in multiple files. May also be used to combine
disjoint collections of data objects. Optional scaling parameters may be given
to rescale the weights of the objects on a per-file basis before merging.

By default the output is written to stdout since we can't guess what would be
a good automatic filename choice! Use the -o option to provide an output filename.


IMPORTANT!
  This script is not meant to handle all run merging situations or data objects:
  there are limitations to what can be inferred from data objects alone. If you
  need to do something more complex than the common cases handled by this script,
  please write your own script / program to load and process the data objects.
"""

import yoda, argparse, sys, math

## Command-line interface: the module docstring doubles as the usage message
parser = argparse.ArgumentParser(usage=__doc__)

## Positional arguments: one or more input files, each optionally suffixed with ":scale"
parser.add_argument(
    "INFILES",
    help="datafile1 datafile2 [...]",
    nargs="+",
)

## Output destination; "-" is the default
parser.add_argument(
    "-o", "--output",
    dest="OUTPUT_FILE",
    metavar="PATH",
    default="-",
    help="write output to specified path",
)

## Path-based selection / rejection patterns
parser.add_argument(
    "-m", "--match",
    dest="MATCH",
    metavar="PATT",
    default=None,
    help="only write out histograms whose path matches this regex",
)
parser.add_argument(
    "-M", "--unmatch",
    dest="UNMATCH",
    metavar="PATT",
    default=None,
    help="exclude histograms whose path matches this regex",
)

## Verbosity level: 0 = quiet, 1 = normal (default), 2 = verbose.
## Both flags store into the same destination.
parser.add_argument(
    "-v", "--verbose",
    dest="VERBOSITY",
    action="store_const",
    const=2,
    default=1,
    help="print extra merging details",
)
parser.add_argument(
    "-q", "--quiet",
    dest="VERBOSITY",
    action="store_const",
    const=0,
    default=1,
    help="only print fatal errors",
)

args = parser.parse_args()


## Loop over all input files, accumulating the stacked objects in aos_out (path -> object).
## For scatters, sources[path] records the error-source names already present in the output.
ntotal = len(args.INFILES)
aos_in, aos_out, sources = None, {}, {}

## Progress messages must not be interleaved with the YODA data itself, so they
## go to stderr when the stacked output is being written to stdout ("-")
msgstream = sys.stderr if args.OUTPUT_FILE == "-" else sys.stdout

fillable_types = (yoda.Counter, yoda.Histo1D, yoda.Histo2D, yoda.Profile1D, yoda.Profile2D)
scatter_types = (yoda.Scatter1D, yoda.Scatter2D, yoda.Scatter3D)

for n, fa in enumerate(args.INFILES):

    ## Split an optional ":scale" suffix off the file argument
    filename, fileweight = fa, 1.0
    if ":" in fa:
        basename, scalestr = fa.rsplit(":", 1)
        try:
            scale = float(scalestr)
        except ValueError:
            ## The text after the last colon is not a number (e.g. the path itself contains
            ## a colon): keep the whole argument as the filename with unit weight, rather
            ## than carrying on with a truncated name and a non-numeric weight
            sys.stderr.write("Error processing arg '%s' with file:scale format: "
                             "treating it as a plain filename with scale 1\n" % fa)
        else:
            filename, fileweight = basename, scale

    ## Update the user on which file is being stacked
    if args.VERBOSITY > 0:
        msg = "Stacking data file {:s} with weight {:f} [{:d}/{:d}]".format(filename, fileweight, n+1, ntotal)
        msgstream.write(msg + "\n")

    ## Release old AOs before loading the new ones to reduce the peak memory usage
    del aos_in

    ## Load the next batch of incoming objects as a dict from path to analysis object,
    ## passing on the user's match/unmatch patterns
    aos_in = yoda.read(filename, True, args.MATCH, args.UNMATCH)

    ## Process each AO in the set from this file
    for aopath, ao in aos_in.items():

        aotype = type(ao)
        israw = '/RAW' in aopath

        ## Counter, Histo and Profile (i.e. Fillable) stacking
        # TODO: Add a Fillable interface and use that for the type matching
        if aotype in fillable_types:

            ## Apply file-based scaling, but not to 'raw' objects
            if not israw:
                ao.scaleW(fileweight)

            ## Initialise the aggregating output object from the first occurrence...
            if aopath not in aos_out:
                aos_out[aopath] = ao.clone()
                continue

            ## ... or add to it, provided the types agree
            aotype_ref = type(aos_out[aopath])
            if aotype == aotype_ref:
                aos_out[aopath] += ao
            elif args.VERBOSITY > 0:
                ## On a type mismatch, skip this object and leave the output untouched
                sys.stderr.write("Unexpected type for analysis object '{}' from {}: {} vs expected {}\n".format(aopath, filename, aotype, aotype_ref))

        ## Scatter "stacking"
        elif aotype in scatter_types:
            ## Retrieve dimensionality of the Scatter*D object
            dim = ao.dim()
            npoints = len(ao.points())
            ## File-based scale factor, not applied to 'raw' objects
            sf = 1.0 if israw else fileweight

            ## Clone the first occurrence and apply the file-based scaling
            if aopath not in aos_out:
                aos_out[aopath] = ao.clone()
                sources[aopath] = ao.variations()
                aos_out[aopath].scale(dim, sf)
                continue

            ao_out = aos_out[aopath]

            ## Point-by-point stacking is only meaningful onto a scatter of the same
            ## type and with the same number of points: otherwise skip this object
            aotype_ref = type(ao_out)
            if aotype != aotype_ref:
                if args.VERBOSITY > 0:
                    sys.stderr.write("Unexpected type for analysis object '{}' from {}: {} vs expected {}\n".format(aopath, filename, aotype, aotype_ref))
                continue
            npoints_ref = len(ao_out.points())
            if npoints != npoints_ref:
                if args.VERBOSITY > 0:
                    sys.stderr.write("Mismatched number of points for scatter '{}' from {}: {} vs expected {}\n".format(aopath, filename, npoints, npoints_ref))
                continue

            ## Decide once per object, before touching any point, which error sources
            ## are new to the output. Registering a new source while processing the first
            ## point would make all later points try to read an error that the output
            ## does not have yet.
            known_vars = sources[aopath]
            ao_vars = ao.variations()
            new_vars = [var for var in ao_vars if var not in known_vars]

            ## Stack this scatter onto the output, applying the file-based scaling
            for i in range(npoints):
                pt_in = ao.point(i)
                pt_out = ao_out.point(i)

                ## Update the central value
                pt_out.setVal(dim, pt_out.val(dim) + sf * pt_in.val(dim))

                ## Update each error source: add the scaled incoming error in quadrature
                ## to the existing one, with a new source starting from zero
                for var in ao_vars:
                    err0, err1 = 0., 0.
                    if var not in new_vars:
                        errs_out = pt_out.errs(dim, var)
                        err0, err1 = errs_out[0], errs_out[1]
                    errs_in = pt_in.errs(dim, var)
                    err0 = math.sqrt(err0**2 + (sf*errs_in[0])**2)
                    err1 = math.sqrt(err1**2 + (sf*errs_in[1])**2)
                    pt_out.setErrs(dim, (err0, err1), var)

            ## From now on the new sources exist on the output points
            known_vars.extend(new_vars)

        ## Other data types: just warn, such objects are not added to the output
        else:
            if args.VERBOSITY > 0:
                sys.stderr.write("WARNING: Analysis object %s of type %s cannot be stacked\n" % (aopath, str(aotype)))


## Write the stacked objects to the requested destination
if args.VERBOSITY > 0:
    ## When the data goes to stdout ("-", the default), the status message is sent
    ## to stderr instead so that it does not end up inside the YODA output stream
    msgstream = sys.stderr if args.OUTPUT_FILE == "-" else sys.stdout
    msgstream.write("Writing stacked data to {}\n".format(args.OUTPUT_FILE))
yoda.writeYODA(aos_out, args.OUTPUT_FILE)

