AlphaCosim := module()
option package;
export normrad,normrads,sampsphere,sampcosim,cosimker,cosimdist,getcost,gradcost,cosimkde,cosimshape,shapemap,cosimfit,costmap,sampshape,greedyland,cosimland,alphafilt,alphacosim,cosimsvd,getcdf;
local radweights_c,gradcost_c,getweights_c,toshape_c,toshape_to,toshape_alloc,toshape_map,testsphere_c,normrads_c,getrads_c;

    # normrad(xx,flag): divide xx by r = sqrt(l2nn(xx)).
    #   xx   - a list or a Vector
    #   flag - optional; when supplied and true, r is returned as a second value
    # Returns the rescaled object (a list for list input, a Vector for Vector
    # input), followed by r when requested.
    normrad := proc(xx,flag)
    local rad, scaled;
        rad := sqrt(l2nn(xx));
        if(type(xx,'Vector')) then
            scaled := xx/rad;
        elif(type(xx,'list')) then
            scaled := map(v->v/rad,xx);
        end if;
        if(nargs=2 and flag) then
            return scaled,rad;
        end if;
        return scaled;
    end proc;

    # normrads_c(A,N,d): compiled kernel. For each of the rows 1..N of A it
    # computes the Euclidean length over columns 1..d and divides the row by
    # that length, overwriting A in place.
    normrads_c := proc(A::Array(datatype=float[8]),N::integer[4],d::integer[4])
    local p, q, nrm;
        for p from 1 to N do
            # accumulate the squared length of row p
            nrm := 0.0;
            for q from 1 to d do
                nrm := nrm+A[p,q]^2;
            end do;
            nrm := sqrt(nrm);
            # rescale row p
            for q from 1 to d do
                A[p,q] := A[p,q]/nrm;
            end do;
        end do;
    end proc;

    normrads_c := Compiler:-Compile(normrads_c);

    # getrads_c(A,rr,N,d): compiled kernel. Stores in rr[i] the Euclidean
    # length of row i of A (columns 1..d) for i = 1..N; A is not modified.
    getrads_c := proc(A::Array(datatype=float[8]),rr::Array(datatype=float[8]),N::integer[4],d::integer[4])
    local p, q, sq;
        for p from 1 to N do
            sq := 0.0;
            for q from 1 to d do
                sq := sq+A[p,q]^2;
            end do;
            rr[p] := sqrt(sq);
        end do;
    end proc;

    getrads_c := Compiler:-Compile(getrads_c);

    # normrads(A) / normrads[true](A,rr): normalise the rows of A.
    #   Called without an index (or with index false) the input is left alone:
    #   a copy made with matf is normalised and returned together with a
    #   vector of the original row lengths.
    #   Called as normrads[true] the rows of A are normalised in place; when rr
    #   is supplied it is first filled with the row lengths. Nothing is returned.
    normrads := proc(A,rr)
    local nrows, ncols, Acopy, radii;
        nrows,ncols := Dimension(A);
        if(type(procname,indexed) and not (op(procname)=false)) then
            # in-place mode
            if(nargs=2) then
                getrads_c(A,rr,nrows,ncols);
            end if;
            normrads_c(A,nrows,ncols);
            return;
        end if;
        # copying mode: delegate to the in-place variant on fresh storage
        Acopy,radii := matf(A),vecf(nrows);
        normrads[true](Acopy,radii);
        return Acopy,radii;
    end proc;

#sample points on the unit sphere by normalising standard normal draws.
#  sampsphere(N,d): returns an N x d matrix whose rows are the samples.
#  sampsphere(n):   returns a single sample as a vector of length n.
    sampsphere := proc(N,d)
    uses Statistics;
    local ans;
        if(nargs=1) then
            # Single-argument form: the only argument is the ambient
            # dimension. The previous code read the (missing) parameter d and
            # then handed a Vector to normrads, which unpacks Dimension into
            # two values; both steps fail. Use the Vector branch of normrad.
            ans := vecf(N);
            Sample(Normal(0,1),ans);
            return normrad(ans);
        end if;
        ans := matf(N,d);
        Sample(Normal(0,1),ans);
        # normalise every row in place
        normrads[true](ans);
        return ans;
    end proc;

    # cosimker[h](xx,yy): cosine-similarity kernel with scale h taken from the
    # index of the call. Returns dotprod(xx,yy)^(1/h^2), or 0.0 when the dot
    # product is not positive.
    cosimker::static := proc(xx,yy)
    local bw, cosine;
        bw := op(procname);
        cosine := dotprod(xx,yy);
        if(cosine>0) then
            return cosine^(1/bw^2);
        end if;
        return 0.0;
    end proc;

#cosimdist[h](s): distance at which the kernel with scale h takes the value s.
#The kernel value s corresponds to the cosine c = s^(h^2), and the returned
#distance is sqrt(2-2*c). Raises an error when s lies outside [0,1], since
#the formula has no real value for s>1 and s<0 is rejected as before.
    cosimdist := proc(s)
    local h, c;
        h := op(procname);
        if(s<0.0 or s>1.0) then
            # previously a bare `error;` for s<0, and s>1 fell through to the
            # square root of a negative number
            error "kernel value must lie in [0,1], but received %1",s;
        end if;
        c := s^(h^2);
        return sqrt(2-2*c);
    end proc;

    # getcost[h](xx,yy): cost -log(dotprod(xx,yy))/h^2 with scale h taken from
    # the index of the call; Float(infinity) when the dot product is not
    # strictly positive.
    getcost := proc(xx,yy)
    local bw, cosine;
        bw := op(procname);
        cosine := dotprod(xx,yy);
        if(not (cosine>0.0)) then
            return Float(infinity);
        end if;
        return -log(cosine)/bw^2;
    end proc;

    # gradcost_c(xx1,yy1,h,V,d): compiled kernel. With c the dot product of
    # xx1 and yy1 over entries 1..d, writes V[j] = (yy1[j]-xx1[j]/c)/h^2.
    # Raises an error when c is not strictly positive.
    gradcost_c := proc(xx1::Array(datatype=float[8]),yy1::Array(datatype=float[8]),h::float[8],V::Array(datatype=float[8]),d::integer[4])
    local k, cosine;
        # dot product of the two inputs
        cosine := 0.0;
        for k from 1 to d do
            cosine := cosine+xx1[k]*yy1[k];
        end do;
        if(cosine<=0.0) then
            error;
        end if;
        for k from 1 to d do
            V[k] := (yy1[k]-1/cosine*xx1[k])/h^2;
        end do;
    end proc;

    gradcost_c := Compiler:-Compile(gradcost_c);

    # gradcost_to(xx,yy,h,V): run the compiled gradient kernel on xx and yy
    # with scale h, writing the result into the caller-supplied vector V.
    gradcost_to := proc(xx,yy,h,V)
    local dim;
        dim := Dimension(xx);
        gradcost_c(xx,yy,h,V,dim);
    end proc;

    # gradcost[h](xx,yy): allocating front end to gradcost_c. Converts xx and
    # yy with tovecf, allocates an output vector of numelems(xx) entries and
    # returns it after the compiled kernel has filled it. The scale h is taken
    # from the index of the call.
    gradcost := proc(xx,yy)
    local bw, dim, xf, yf, out;
        bw := op(procname);
        dim := numelems(xx);
        xf := tovecf(xx);
        yf := tovecf(yy);
        out := vecf(dim);
        gradcost_c(xf,yf,bw,out,dim);
        return out;
    end proc;

    # getweights_c(xx,A,h,aa,R,cc,tt,N,d): compiled kernel evaluating the
    # weighted kernel sum at the site xx.
    #   cc[i] <- dot product of row i of A with xx
    #   R[i]  <- -log(cc[i])/h^2 when cc[i]>0, Float(infinity) otherwise
    #   tt[i] <- aa[i]*exp(-R[i]+r0) normalised to sum to one, where r0 is the
    #            smallest finite R[i] (subtracted before exponentiating)
    # Returns r0-log(sum), i.e. minus the log of sum_i aa[i]*exp(-R[i]).
    # When no cc[i] is positive the site is outside the support: tt is zeroed
    # and Float(infinity) is returned.
    getweights_c := proc(xx::Array(datatype=float[8]),A::Array(datatype=float[8]),h::float[8],aa::Array(datatype=float[8]),R::Array(datatype=float[8]),cc::Array(datatype=float[8]),tt::Array(datatype=float[8]),N::integer[4],d::integer[4])
    local i, j, c, r, r0, dens, npos, ans;
        r0 := Float(infinity);
        npos := 0;

        for i from 1 to N do
            c := 0.0;
            for j from 1 to d do
                c := c+A[i,j]*xx[j];
            end do;
            cc[i] := c;
            if(c>0.0) then
                r := -log(c)/h^2;
                R[i] := r;
                r0 := min(r0,r);
                npos := npos+1;
            else
                R[i] := Float(infinity);
            end if;
        end do;
        if(npos=0) then
            # Every cost is infinite, so r0 is still infinite and
            # exp(-R[i]+r0) would evaluate exp(-inf+inf), which is NaN and
            # would make the result NaN instead of infinity.
            for i from 1 to N do
                tt[i] := 0.0;
            end do;
            ans := Float(infinity);
        else
            dens := 0.0;
            for i from 1 to N do
                tt[i] := aa[i]*exp(-R[i]+r0);
                dens := dens+tt[i];
            end do;
            for i from 1 to N do
                tt[i] := tt[i]/dens;
            end do;
            ans := r0-log(dens);
        end if;
        return ans;
    end proc;

    getweights_c := Compiler:-Compile(getweights_c);

    # gradnld_c(xx,A,h,cc,tt,V,N,d): compiled kernel. For each coordinate j it
    # writes V[j] = (sum over i with cc[i]>0 of tt[i]*(xx[j]-A[i,j]/cc[i]))/h^2.
    # cc and tt are read as produced by a preceding getweights_c call.
    gradnld_c := proc(xx::Array(datatype=float[8]),A::Array(datatype=float[8]),h::float[8],cc::Array(datatype=float[8]),tt::Array(datatype=float[8]),V::Array(datatype=float[8]),N::integer[4],d::integer[4])
    local coord, pt, cosine, acc;
        for coord from 1 to d do
            acc := 0.0;
            for pt from 1 to N do
                cosine := cc[pt];
                # points with non-positive cosine contribute nothing
                if(cosine<=0.0) then
                    next;
                end if;
                acc := acc+tt[pt]*(xx[coord]-A[pt,coord]/cosine);
            end do;
            V[coord] := acc/h^2;
        end do;
    end proc;

    gradnld_c := Compiler:-Compile(gradnld_c);

    # sampcosim_c(A,J,V,B,N,M,d): compiled kernel that turns each row i of B
    # (i = 1..M) into a unit vector near the row J[i] of A, in place:
    #   1. remove from B[i,..] its component along A[J[i],..];
    #   2. scale the remainder so that, added to A[J[i],..], the tangent of the
    #      angle to A[J[i],..] is sqrt((1-V[i])/V[i]);
    #   3. normalise the sum to unit length.
    sampcosim_c := proc(A::Array(datatype=float[8]),J::Array(datatype=integer[4]),V::Array(datatype=float[8]),B::Array(datatype=float[8]),N::integer[4],M::integer[4],d::integer[4])
    local s, k, src, proj, nrm, tang;
        for s from 1 to M do
            src := J[s];
            # component of the draw along the chosen centre
            proj := 0.0;
            for k from 1 to d do
                proj := proj+A[src,k]*B[s,k];
            end do;
            for k from 1 to d do
                B[s,k] := B[s,k]-proj*A[src,k];
            end do;
            # length of the orthogonal remainder
            nrm := 0.0;
            for k from 1 to d do
                nrm := nrm+B[s,k]^2;
            end do;
            nrm := sqrt(nrm);
            tang := sqrt((1-V[s])/V[s])/nrm;
            for k from 1 to d do
                B[s,k] := A[src,k]+tang*B[s,k];
            end do;
            # renormalise the result
            nrm := 0.0;
            for k from 1 to d do
                nrm := nrm+B[s,k]^2;
            end do;
            nrm := sqrt(nrm);
            for k from 1 to d do
                B[s,k] := B[s,k]/nrm;
            end do;
        end do;
    end proc;

    sampcosim_c := Compiler:-Compile(sampcosim_c);

    # sampcosim_alloc(M): scratch buffers of size M for sampcosim, one made
    # with vecf and one with veci. `option remember` means repeated calls with
    # the same M return the same pair of buffers.
    sampcosim_alloc := proc(M)
    option remember;
    local fbuf, ibuf;
        fbuf := vecf(M);
        ibuf := veci(M);
        return fbuf,ibuf;
    end proc;

#sampcosim(A,h,aa,B): draw samples around the rows of A.
#  B may be a number of samples (a fresh matrix with that many rows is then
#  allocated) or a matrix that is overwritten with one sample per row.
#  Row indices into A are drawn by sampfin from aa, the squared cosine to the
#  chosen row is drawn from BetaDistribution(1/2/h^2+1/2,(d-1)/2), and the
#  direction comes from standard normal draws; sampcosim_c combines the three.
    sampcosim := proc(A,h,aa,B)
    local npts, dim, nsamp, bufsize, betas, picks;
        npts,dim := Dimension(A);
        if(type(B,'numeric')) then
            return sampcosim(A,h,aa,matf(B,dim));
        end if;
        nsamp := Dimension(B)[1];
        # scratch buffers are cached per size, so round up to a power of two
        bufsize := 2^ceil(log[2](nsamp));
        betas,picks := sampcosim_alloc(bufsize);
        Sample(Normal(0,1),B);
        sampfin(aa,picks);
        Sample(BetaDistribution(1/2/h^2+1/2,(dim-1)/2),betas);
        sampcosim_c(A,picks,betas,B,npts,nsamp,dim);
        return B;
    end proc;

    # testsphere_c(A,N,d): compiled kernel returning the largest value of
    # abs(length-1.0) over the rows 1..N of A, where the length is taken over
    # columns 1..d. Zero means every row has unit length.
    testsphere_c := proc(A::Array(datatype=float[8]),N::integer[4],d::integer[4])
    local p, q, nrm, worst;
        worst := 0.0;
        for p from 1 to N do
            nrm := 0.0;
            for q from 1 to d do
                nrm := nrm+A[p,q]^2;
            end do;
            nrm := sqrt(nrm);
            worst := max(worst,abs(nrm-1.0));
        end do;
        return worst;
    end proc;

    testsphere_c := Compiler:-Compile(testsphere_c);

#kernel density estimator for cosine similarity raised to the
#power of 1/h^2.
#  cosimkde(A,h,aa): A holds one unit vector per row, h is the scale and aa
#                    the per-point weights.
#  cosimkde(A,h):    as above with uniform weights 1/N.
#  cosimkde(f):      rebuild from the data of an existing CosimKDE object.
#Raises "data not spherical" when some row of A deviates from unit length by
#more than 1e-6. Returns an object; applying it to a point gives the density.
    cosimkde := proc(A,h,aa)
        if(whattype(args[1])='CosimKDE') then
            return procname(args[1]:-getdata());
        elif(nargs=2) then
            N := Dimension(A)[1];
            return cosimkde(A,h,vecf([seq(1/N,i=1..N)]));
        end if;
        if(testsphere_c(A,Dimension(A))>.000001) then
            error "data not spherical";
        end if;
        md := module()
        option object;
        export A,h,aa,`whattype`,updatesite,getsite,getdata,getdim,numpoints,getpoints,getscale,getinds,getweights,getdens,getdens_i,getnld,gradnld_to,gradnld,getnld_i,getconv,getcosts,getcosines,init,insupp,sample;
        local ModulePrint,ModuleApply,R,xx,V,cc,tt,N,d,densval,nldval,suppflag;
            ModulePrint::static := proc()
                return nprintf("spherical KDE in S^%d",d-1);
            end proc;
            `whattype`::static := proc()
                return 'CosimKDE';
            end proc;
            # Move the evaluation site to xx1 and refresh every cached
            # quantity (cosines cc, costs R, normalised weights tt, negative
            # log density and density). Returns whether the negative log
            # density differs from Float(infinity).
            updatesite::static := proc(xx1)
                setvec(xx,xx1);
                nldval := getweights_c(xx,A,h,aa,R,cc,tt,N,d);
                densval := exp(-nldval);
                suppflag := convert(nldval<>Float(infinity),truefalse);
                return suppflag;
            end proc;
            # Current evaluation site. This used to return the name xx1, which
            # is not bound in this scope; the stored site is xx.
            getsite::static := ()->xx;
            # Plain accessors for the data and for the quantities cached by
            # the last updatesite call.
            getdata::static := ()->A,h,aa;
            getdim::static := ()->d;
            numpoints::static := ()->N:
            getpoints::static := ()->A;
            getscale::static := ()->h;
            getweights::static := ()->aa;
            getcosines::static := ()->cc;
            getconv::static := ()->tt;
            getcosts::static := ()->R;
            # indices of the data points with positive cosine to the site
            getinds::static := ()->[seq(`if`(cc[i]>0.0,i,NULL),i=1..N)];
            insupp::static := ()->suppflag;
            getdens_i::static := ()->densval;
            getnld_i::static := ()->nldval;
            # density at xx1 (moves the site)
            getdens::static := proc(xx1)
                updatesite(xx1);
                return densval;
            end proc;
            # negative log density at xx1 (moves the site)
            getnld::static := proc(xx1)
                updatesite(xx1);
                return nldval;
            end proc;
            ModuleApply::static := getdens;
            # Gradient of the negative log density at the current site,
            # written into V. Errors when the site is outside the support.
            gradnld_to::static := proc(V)
                if(suppflag) then
                    gradnld_c(xx,A,h,cc,tt,V,N,d);
                    return V;
                end if;
                error "not in the support";
            end proc;
            # allocating variant of gradnld_to that first moves the site
            gradnld::static := proc(xx1)
                updatesite(xx1);
                return gradnld_to(vecf(d));
            end proc;
            sample::static := proc(M)
                return sampcosim(A,h,aa,M);
            end proc;
            # Store the data and allocate the work arrays.
            # NOTE(review): A1 is allocated here but never read afterwards.
            init::static := proc()
                A,h,aa := args;
                N,d := Dimension(A);
                A1,xx,R,cc,tt := allocla[float[8]]([N,d],d,N,N,N);
                densval,nldval := 0.0,Float(infinity);
                suppflag := false;
            end proc;
        end module;
        md:-init(args);
        return md;
    end proc;

    # cosimsvd_c(A,aa,C,N,d): compiled kernel filling the d x d array C with
    # C[j1,j2] = sum over i = 1..N of A[i,j1]*A[i,j2]*aa[i].
    cosimsvd_c := proc(A::Array(datatype=float[8]),aa::Array(datatype=float[8]),C::Array(datatype=float[8]),N::integer[4],d::integer[4])
    local u, v, pt, acc;
        for u from 1 to d do
            for v from 1 to d do
                acc := 0.0;
                for pt from 1 to N do
                    acc := acc+A[pt,u]*A[pt,v]*aa[pt];
                end do;
                C[u,v] := acc;
            end do;
        end do;
    end proc;

    cosimsvd_c := Compiler:-Compile(cosimsvd_c);

    # radweights_c(A,aa,N,d): compiled kernel. Normalises each row i = 1..N of
    # A to unit length in place and multiplies the weight aa[i] by the square
    # of the length the row had before normalisation.
    radweights_c := proc(A::Array(datatype=float[8]),aa::Array(datatype=float[8]),N::integer[4],d::integer[4])
    local p, q, nrm;
        for p from 1 to N do
            nrm := 0.0;
            for q from 1 to d do
                nrm := nrm+A[p,q]^2;
            end do;
            nrm := sqrt(nrm);
            for q from 1 to d do
                A[p,q] := A[p,q]/nrm;
            end do;
            # fold the discarded radius into the weight
            aa[p] := aa[p]*nrm^2;
        end do;
    end proc;

    radweights_c := Compiler:-Compile(radweights_c);

    # cosimsvd(f,d1): reduce the estimator f to d1 dimensions.
    # Builds the weighted matrix C from the points and weights of f with
    # cosimsvd_c, keeps the first d1 columns of the U factor of its singular
    # value decomposition, maps the points onto those columns, renormalises
    # them with radweights_c (adjusting a copy of the weights) and builds a new
    # estimator with the same scale.
    # Returns the new estimator and the transpose of the retained columns.
    cosimsvd := proc(f,d1)
    local pts, bw, wts, npts, dim, mom, basis, proj, wts1, reduced;
        pts := f:-getpoints();
        bw := f:-getscale();
        wts := f:-getweights();
        npts,dim := Dimension(pts);
        mom := matf(dim,dim);
        cosimsvd_c(pts,wts,mom,npts,dim);
        basis := SingularValues(mom,output='U')[..,1..d1];
        proj := pts.basis;
        # work on a copy so the weights of f are left untouched
        wts1 := vecf(wts);
        radweights_c(proj,wts1,npts,d1);
        reduced := cosimkde(proj,bw,wts1);
        return reduced,Transpose(basis);
    end proc;

    # toshape_c(xx,h,yy,d): compiled kernel. On entry yy holds a vector g; with
    # t = h^2*|g| and c = 1/sqrt(1+t^2), yy is overwritten by c*(xx-h^2*g).
    # Returns log(a)/h^2 where a is the dot product of xx with the new yy.
    toshape_c::static := proc(xx::Array(datatype=float[8]),h::float[8],yy::Array(datatype=float[8]),d::integer[4])
    local k, glen, tang, scale, cosine;
        # length of the incoming vector
        glen := 0.0;
        for k from 1 to d do
            glen := glen+yy[k]^2;
        end do;
        glen := sqrt(glen);
        tang := h^2*glen;
        scale := 1/sqrt(1+tang^2);
        # overwrite yy and accumulate its dot product with xx in one pass
        cosine := 0.0;
        for k from 1 to d do
            yy[k] := scale*(xx[k]-h^2*yy[k]);
            cosine := cosine+xx[k]*yy[k];
        end do;
        return log(cosine)/h^2;
    end proc;

    toshape_c := Compiler:-Compile(toshape_c);

    # toshape_to(f,xx,yy): evaluate the shape map of the estimator f at xx,
    # writing the image point into yy. Returns the negative log density of f
    # at xx plus the value returned by toshape_c, or Float(infinity) (leaving
    # yy untouched) when f:-updatesite reports that xx is outside the support.
    toshape_to := proc(f,xx,yy)
    local dim, bw, nld;
        dim,bw := f:-getdim(),f:-getscale();
        if(not f:-updatesite(xx)) then
            return Float(infinity);
        end if;
        nld := f:-getnld_i();
        # yy first receives the gradient, then toshape_c maps it in place
        f:-gradnld_to(yy);
        return nld+toshape_c(xx,bw,yy,dim);
    end proc;

    # toshape_alloc(f,xx): allocating variant of toshape_to. Returns the image
    # point (a fresh vector of length f:-getdim()) followed by the value.
    toshape_alloc := proc(f,xx)
    local img, val;
        img := vecf(f:-getdim());
        val := toshape_to(f,xx,img);
        return img,val;
    end proc;

    # toshape_map(f,X,Y,aa): apply toshape_to to every row of X.
    #   Y  - optional output matrix, by default allocated with the shape of X
    #   aa - optional output vector, by default one entry per row of X
    # Row i of Y receives the image of row i of X and aa[i] its value.
    # Returns Y,aa.
    toshape_map := proc(f,X,Y:=matf(Dimension(X)),aa:=vecf(Dimension(X)[1]))
    local dim, bw, nrows, src, img, k;
        dim,bw := f:-getdim(),f:-getscale();
        nrows := Dimension(X)[1];
        # two scratch vectors reused for every row
        src,img := vecf(dim),vecf(dim);
        for k to nrows do
            getrow_c(X,k,src,dim);
            aa[k] := toshape_to(f,src,img);
            setrow_c(Y,k,img,dim);
        end do;
        return Y,aa;
    end proc;

    # shapemap: object bundling the three shape-map entry points.
    #   cconj_to    - toshape_to    (write into a supplied vector)
    #   cconj_alloc - toshape_alloc (allocate the result for one point)
    #   cconj_map   - toshape_map   (process the rows of a matrix)
    # Applying the object as shapemap(f,xx) dispatches on xx: a Vector goes to
    # toshape_alloc and a Matrix to toshape_map; anything else returns nothing.
    shapemap := module()
    option object;
    export cconj_alloc,cconj_to,cconj_map;
    local ModuleApply,ModulePrint;
        cconj_alloc::static := toshape_alloc;
        cconj_to::static := toshape_to;
        cconj_map::static := toshape_map;
        ModulePrint::static := proc()
            return nprintf("cosine similarity transport map and value");
        end proc;
        ModuleApply::static := proc(f,xx)
            if(type(xx,'Matrix')) then
                return toshape_map(args);
            end if;
            if(type(xx,'Vector')) then
                return toshape_alloc(args);
            end if;
        end proc;
    end module;

    # cosimfit[h](xx,V,a): run toshape_c on a copy of V at the point xx, with
    # the scale h taken from the index of the call.
    #   cosimfit[h](xx,V)   returns the mapped copy of V
    #   cosimfit[h](xx,V,a) returns a plus the value computed by toshape_c
    cosimfit := proc(xx,V,a)
    local bw, dim, img, val;
        bw := op(procname);
        dim := numelems(xx);
        # copy, so that V itself is not overwritten
        img := vecf(V);
        val := toshape_c(xx,bw,img,dim);
        if(nargs<>2) then
            return a+val;
        end if;
        return img;
    end proc;

    # costmap[h](yy,a): returns the function x -> getcost[h](yy,x)+a, with the
    # scale h taken from the index of the call.
    # The closure previously referred to the undefined name aa (the parameter
    # is a) and added it to the point instead of to the cost.
    # NOTE(review): the placement of the offset is inferred from cosimfit,
    # which returns a value of the form a+(cost-like term) - confirm.
    costmap := proc(yy,a)
    local h;
        h := op(procname);
        return ()->getcost[h](yy,args[1])+a;
    end proc;

    # sampshape(f,M,...): draw M samples from f and push them through
    # toshape_map; any arguments after the second are forwarded to toshape_map.
    # Returns the mapped points, their values and the raw samples.
    sampshape := proc(f,M)
    local raw, mapped, vals;
        raw := f:-sample(M);
        mapped,vals := toshape_map(f,raw,args[3..nargs]);
        return mapped,vals,raw;
    end proc;

    # greedyland_c(S,eps,inds,M,d): compiled greedy selection of well-separated
    # rows. Rows 1..M of S are visited in order; row k is accepted when its
    # Euclidean distance (over columns 1..d) to every previously accepted row
    # is at least eps. The accepted row numbers are written to inds[1..n] and
    # the count n is returned.
    greedyland_c := proc(S::Array(datatype=float[8]),eps::float[8],inds::Array(datatype=integer[4]),M::integer[4],d::integer[4])
        n := 0;
        for k from 1 to M do
            # compare row k with each row accepted so far
            for i from 1 to n do
                k1 := inds[i];
                r := 0.0;
                for j from 1 to d do
                    x := S[k1,j]-S[k,j];
                    r := r+x*x;
                end do;
                r := sqrt(r);
                if(r<eps) then
                    break;
                end if;
            end do;
            # i only reaches n+1 when the scan finished without a break, so a
            # smaller i means row k is too close to an accepted row
            if(i<n+1) then
                next;
            end if;
            n := n+1;
            inds[n] := k;
        end do;
        return n;
    end proc;

    greedyland_c := Compiler:-Compile(greedyland_c);

#greedyland(S,aa,eps,a1): go through the rows of S in increasing order of aa,
#keeping each row that is at least eps away from every row kept so far.
#Trailing rows in that order are discarded until one is found whose value is
#at most a1 and is not Float(infinity). Returns the indices (into S) of the
#rows that were kept.
    greedyland := proc(S,aa,eps,a1:=Float(infinity))
    local N, d, ord, nkeep, a, cand, inds, n;
        N,d := Dimension(S);
        ord := sort(aa,`<`,output=permutation);
        # walk back from the largest value to the last admissible one
        nkeep := N;
        while(nkeep>=1) do
            a := aa[ord[nkeep]];
            if(a<=a1 and a<>Float(infinity)) then
                break;
            end if;
            nkeep := nkeep-1;
        end do;
        cand := S[convert(ord[1..nkeep],'list')];
        inds := allocla[integer[4]](nkeep);
        n := greedyland_c(cand,eps,inds,nkeep,d);
        # translate positions in the sorted order back to rows of S
        return ord[convert(inds[1..n],'list')];
    end proc;

    #cosimland(f,T,s,mindens): return landmarks by proceeding in increasing
    #value of aa, accepting a point when its distance is at least eps from all
    #chosen points, where eps = cosimdist[h](s) and h is the scale of f.
    #  T       - a matrix of sample points, or a number of points to sample
    #  mindens - points whose value exceeds -log(mindens) are not considered
    #Returns the selected mapped points, their values and the matching rows of T.
    cosimland := proc(f,T,s,mindens:=0.0)
    local h, M, S, aa, eps, a1, inds;
        if(not type(T,'numeric')) then
            h := f:-getscale();
            M := Dimension(T)[1];
            S,aa := toshape_map(f,T);
            eps := cosimdist[h](s);
            a1 := -log(mindens);
            inds := convert(greedyland(S,aa,eps,a1),'list');
            return S[inds],aa[inds],T[inds];
        else
            # T is a sample count: draw the points first
            return cosimland(f,f:-sample(T),s,mindens);
        end if;
    end proc;

    # alphafilt(S,aa,r,k1): build a filtered complex on the rows of S.
    # For every dimension 0..k1, each simplex returned by alpharadii(S,r,k1) is
    # added to an fplex on n vertices with filtration value equal to the
    # largest aa entry among its vertices. Returns the filtered complex.
    alphafilt := proc(S,aa,r,k1)
    local npts, cplx, filt, dim, simplex, v, val;
        npts := Dimension(S)[1];
        cplx := alpharadii(S,r,k1);
        filt := fplex(npts);
        for dim from 0 to k1 do
            for simplex in cplx[dim] do
                val := max([seq(aa[v],v=simplex)]);
                filt:-addfilt(simplex,val);
            end do;
        end do;
        return filt;
    end proc;

    # alphacosim(f,M,s,k1,mindens): select landmarks with cosimland and build
    # the filtered complex on them with alphafilt, using the radius
    # cosimdist[h](s) where h is the scale of f.
    # Returns the complex, the landmarks, their values and the source samples.
    alphacosim := proc(f,M,s,k1,mindens:=0.0)
    local bw, marks, vals, raw, radius, cplx;
        bw := f:-getscale();
        marks,vals,raw := cosimland(f,M,s,mindens);
        radius := cosimdist[bw](s);
        cplx := alphafilt(marks,vals,radius,k1);
        return cplx,marks,vals,raw;
    end proc;

#get the density value at the cutoff quantile value of 0<=s<=1.
#M data points are drawn by sampfin from the weights of f, their negative log
#densities are sorted in decreasing order, and exp(-V[k]) is returned for
#k = ceil(M*(1-s)), clamped to the range 1..M.
    getcdf := proc(f,s,M:=10000)
    local A, aa, J, V, k, i;
        A,aa := f:-getpoints(),f:-getweights();
        J := sampfin(aa,M);
        V := vecf([seq(f:-getnld(A[J[i]]),i=1..M)]);
        sort[inplace](V,`>`);
        # s=1 gives ceil(0)=0, which is not a valid index; clamp to 1..M
        k := min(M,max(1,ceil(M*(1-s))));
        return exp(-V[k]);
    end proc;

end module;
