import optparse
import os as os
from scipy import *
from pylab import *
##### PLOT FUNCTION
def plotres(XY,GT,CN,freq,T,DD,cnv):
    """Draw the SNP genotype and CNV score panels of one probe and save them.

    Parameters
    ----------
    XY : sequence of numbers or numeric strings
        Interleaved intensities: even positions are X values, odd positions
        are Y values. Samples with a negative X or Y are discarded.
    GT : sequence of numbers or numeric strings
        Per-sample genotype codes. Codes 0, 1, 2 and 3 are drawn in colour
        and labelled "0", "AA", "AB" and "BB"; every other sample stays in
        the black "uncall" layer.
    CN : sequence of numbers or numeric strings
        Per-sample copy-number scores. Negative scores are left out of the
        frequency computation and of the right-hand panel.
    freq : float
        Minimum fraction of CNV-carrying samples required to draw the plot.
    T : sequence of two str
        Titles of the left (SNP) and right (CNV) panels.
    DD : str
        Destination handed to ``savefig``.
    cnv : sequence of two floats
        ``cnv[0]`` is the deletion threshold, ``cnv[1]`` the amplification
        threshold.

    Returns
    -------
    int
        -1 when no sample has non-negative X and Y, 0 when the CNV frequency
        is below ``freq`` (nothing is drawn), 1 when the figure was saved.
    """
    gtcol = ["y.","r.","g.","b."]
    GTL   = ["0","AA","AB","BB"]
    XY = array(XY,'float')
    X  = XY[0::2]
    Y  = XY[1::2]
    GT = array(GT,'float')
    CN = array(CN,'float')
    index = ((X>=0)&(Y>=0)).nonzero()[0]
    if len(index)==0:
        return -1
    # Fancy indexing returns a copy, so the clipping further down never
    # touches the caller's data.
    CN = CN[index]
    called = int((CN>=0).sum())
    if called>0:
        # The three categories are counted separately: amplifications,
        # deletions with a positive score, and scores equal to zero.
        altered = (int((CN>=cnv[1]).sum())
                   + int(((CN<=cnv[0])&(CN>0)).sum())
                   + int((CN==0).sum()))
        f = float(altered)/called
    else:
        f = 0
    if not f>=freq:
        return 0
    X  = X[index]
    Y  = Y[index]
    GT = GT[index]
    # Both panels share the same square axis range.
    top = 1.1*max(concatenate((X,Y)))
    figure(dpi=250,figsize=[15,8])
    try:
        # Left panel: all samples in black first, called genotypes on top.
        # Successive plot() calls accumulate on the axes by default, so the
        # removed pyplot hold() switch is not needed.
        subplot(121)
        plot(X,Y,'k.',markersize=12)
        L = ["uncall"]
        for i in range(4):
            ii = (GT==i).nonzero()[0]
            if len(ii)>0:
                plot(X[ii],Y[ii],gtcol[i],markersize=12)
                L.append(GTL[i])
        grid()
        xlim(-0.01,top)
        ylim(-0.01,top)
        title(T[0])
        legend(L)
        xlabel("X")
        ylabel("Y")
        # Right panel: samples with a non-negative score, coloured by score.
        subplot(122)
        kk = (CN>=0).nonzero()[0]
        CN[CN>3] = 3
        # Two dummy points at (-1,-1), outside the visible axes, carry the
        # values 0 and 3 so the colour scale always spans 0..3.
        scatter(concatenate((array([-1,-1]),X[kk])),
                concatenate((array([-1,-1]),Y[kk])),
                s=28, c=concatenate((array([0,3]),CN[kk])),
                marker='o', cmap='jet', norm=None)
        colorbar()
        grid()
        xlim(-0.01,top)
        ylim(-0.01,top)
        title(T[1])
        xlabel("X")
        ylabel("Y")
        savefig(DD)
    finally:
        # Always release the figure; this function is called once per probe
        # and leaked figures would pile up in memory.
        close()
    return 1
##### PARSER
# Command-line interface: either a common prefix (-p) or the four individual
# files (-i, -s, -c, -L) must be given, plus the figure directory (-d).
desc = "This program generates SNP and CNV genotype plots from GStream output data."
parser = optparse.OptionParser(description=desc)
parser.add_option("-p", "--pref", dest="pref", help="prefix for intensity, SNP, CNV and log files")
parser.add_option("-i", "--int", dest="fxy", help="intensity file (required if prefix not provided)")
parser.add_option("-s", "--snp", dest="fsnp", help="SNP genotype file (required if prefix not provided)")
parser.add_option("-c", "--cnv",dest="fcnv", help="CNV genotype file (required if prefix not provided)")
parser.add_option("-L", "--log",dest="flog", help="GStream log file (required if prefix not provided)")
parser.add_option("-d", "--dir",dest="dir", help="directory to save figures (required)")
parser.add_option("-f", "--freq", dest="f", default=0.02, type='float', help="lower frequency threshold for plotting SNP probe (optional)")
parser.add_option("-A", "--tamp", dest="TA", default=2.8, type='float', help="score threshold for amplifications (optional)")
parser.add_option("-D", "--tdel", dest="TD", default=1.6, type='float', help="score threshold for deletions (optional)")
(opts, args) = parser.parse_args()

if opts.pref is None:
    # Without a prefix every input file has to be named explicitly. The log
    # file is checked too: it is read later on, and leaving it unset used to
    # end in a TypeError instead of a usage message.
    for _value, _label in ((opts.fxy, "Intensity"),
                           (opts.fsnp, "SNP genotype"),
                           (opts.fcnv, "CNV genotype"),
                           (opts.flog, "Log")):
        if _value is None:
            print("\n"+_label+" file is missing!!\n")
            parser.print_help()
            raise SystemExit(-1)
else:
    # The prefix fixes all four file names.
    opts.fxy = opts.pref+".txt"
    opts.fsnp= opts.pref+".snp"
    opts.fcnv= opts.pref+".cnv"
    opts.flog= opts.pref+".log"
if opts.dir is None:
    print("\nFigure directory is missing!!\n")
    parser.print_help()
    raise SystemExit(-1)
##### VERIFY
# Test each path directly instead of searching a listing of the current
# directory: os.listdir("") fails on current Pythons, and a listing can never
# match a file given with a directory component.
for _path, _label in ((opts.fxy, "Intensity file"),
                      (opts.fsnp, "SNP genotype file"),
                      (opts.fcnv, "CNV genotype file"),
                      (opts.flog, "Log file")):
    # A path may still be None when its option was not supplied.
    if _path is None or not os.path.isfile(_path):
        print(_label+" "+str(_path)+" not found...")
        raise SystemExit(-1)
# The chained comparison also rejects NaN, which could never be reached as a
# plotting threshold.
if not (0 <= opts.f < 1):
    print("Frequency value not valid... f="+str(opts.f)+" -> 0<=f<1")
    raise SystemExit(-1)
if os.path.isdir(opts.dir):
    print("Figure directory found...")
else:
    print("Figure directory not found. Created.")
    # makedirs also creates missing parent directories.
    os.makedirs(opts.dir)
##### READ LOG FILE FOR OMITTED PROBES
# Collect the probe names GStream reports as omitted; the main loop uses them
# to keep the intensity file aligned with the genotype files.
print("Reading omitted probes in log file...")
omitted = list()
with open(opts.flog,"r") as _log:
    for _line in _log:
        # Require the exact "locus " marker used for splitting, so a line
        # that merely contains the word "locus" cannot raise an IndexError.
        if "locus " in _line:
            omitted.append(_line.split("locus ")[1].split(" can")[0])
print("\t"+str(len(omitted))+" omitted probes")
##### MAIN
# Walk the intensity file probe by probe. Probes listed in the log as omitted
# consume no line of the SNP/CNV files; every other probe consumes exactly
# one line of each.
_omitted = set(omitted)  # constant-time membership test per probe
with open(opts.fxy,'r') as xy, \
        open(opts.fsnp,'r') as snp, \
        open(opts.fcnv,'r') as cnv:
    # Discard the first (header) line of each file.
    xy.readline()
    snp.readline()
    cnv.readline()
    for _line in xy:
        if not _line.strip():
            # Blank lines carry no probe.
            continue
        ixy = _line.strip("\n").replace("NeuN","-1").split("\t")
        snpp= ixy[0]
        ch  = ixy[1]
        bp  = ixy[2]
        XY  = ixy[3:]
        if snpp in _omitted:
            print(snpp+" has been omitted by GStream (not plotted)")
            continue
        _snp_line = snp.readline()
        _cnv_line = cnv.readline()
        if not _snp_line or not _cnv_line:
            # readline() returns an empty string at end of file; stop with a
            # clear message instead of failing on a missing column.
            print("SNP/CNV genotype files ended before probe "+snpp+"; stopping")
            break
        isnp = _snp_line.strip("\n").split(" ")
        icnv = _cnv_line.strip("\n").split(" ")
        isnph= isnp[0:4]
        GT   = isnp[4:]
        icnvh= icnv[0:5]
        CN   = icnv[5:]
        T    = ["",""]
        T[0] = isnph[0]+"(chr"+isnph[1]+":"+isnph[2]+")\nSNP GENOTYPING: QC="+isnph[3]
        T[1] = icnvh[0]+"(chr"+icnvh[1]+":"+icnvh[2]+")\nCNV GENOTYPING: NC="+icnvh[3]+"-MR="+icnvh[4]
        _png = os.path.join(opts.dir,"chr"+str(ch)+"_"+str(bp)+"."+snpp+".png")
        plotres(XY,GT,CN,opts.f,T,_png,[opts.TD,opts.TA])
