with(LinearAlgebra);
with(Statistics):

Stats := module()
option package;
export randelt,randf,randrange,randi,rowsum,colsum,rownorm,colnorm,sample,rsubset,rv,tnorm,rsphere,dker,cis,meanstd,normdata,subpca,pca,meannorm,meancent,svd,mpca,normsamp,sampgauss,covmat,rowsamp,getpca,getsvd,randrow,normrows,sampfin;
local sqrt2,rt2pi,ModuleLoad;

    # Pick one element uniformly at random.
    #   xl - an integer range a..b (an integer in the range is returned),
    #        a Matrix (one of its rows is returned), or any container that
    #        numelems understands (one element is returned).
    randelt := proc(xl)
    local count,lo,hi;
    uses LinearAlgebra;
        if(type(xl,`..`)) then
            lo,hi := op(xl);
            return lo+(rand() mod (hi-lo+1));
        end if;
        # A Matrix is indexed by row below, so count rows rather than entries.
        count := `if`(type(xl,'Matrix'),Dimension(xl)[1],numelems(xl));
        return xl[(rand() mod count)+1];
    end proc;

    # Uniform pseudo-random float.
    #   randf()    - value between 0 and 1
    #   randf(b)   - value between 0 and b
    #   randf(a,b) - value between a and b
    # The interval is cut into N1 equal cells and the midpoint of a randomly
    # chosen cell is returned.
    randf := proc(a,b)
    local N1;
        if(nargs=0) then
            return procname(0,1);
        elif(nargs=1) then
            return procname(0,args[1]);
        end if;
        # rand() yields at most 12-digit integers.  Using 10^Digits as the cell
        # count when Digits>12 would make `rand() mod N1` equal to rand() and
        # squeeze every result into a tiny sliver next to a, so cap at 12.
        N1 := 10^min(Digits,12);
        return a+(b-a)*(rand() mod N1+.5)/N1;
    end proc;

    # Binary search over tt[1..N]: returns the smallest index i in 1..N-1
    # with c < tt[i], or N when no such index exists (tt[N] itself is never
    # compared).  Intended for nondecreasing tt; sampfin passes cumulative sums.
    binsearch_c := proc(c::float[8],tt::Array(datatype=float[8]),N::integer[4])
    local lo::integer[4],hi::integer[4],mid::integer[4];
        lo := 1;
        hi := N;
        while(lo<hi) do
            mid := floor((lo+hi)/2);
            if(tt[mid]>c) then
                # answer is at mid or to its left
                hi := mid;
            else
                lo := mid+1;
            end if;
        end do;
        return lo;
    end proc;

    # Replace the interpreted procedure by its natively compiled version.
    binsearch_c := Compiler:-Compile(binsearch_c);

    # Scratch float[8] Vector of length N.  Because of `option remember`, the
    # Vector created for a given N is stored and the same object is returned on
    # later calls with that N.
    sampfin_alloc := proc(N)
    option remember;
        Vector(N,datatype=float[8]);
    end proc;

    # Draw indices from 1..numelems(aa) using the entries of aa as weights.
    #   aa - weights (need not sum to 1; the running total is used as the scale)
    #   J  - either the number of draws, or an existing Vector that is filled
    #        in place (its length gives the number of draws)
    # Returns the filled index Vector.
    sampfin := proc(aa,J)
    local N,M,cum,total,i;
        if(type(J,'numeric')) then
            # allocate the output and re-enter with a buffer
            return procname(aa,Vector(J,datatype=integer[4]));
        end if;
        N := numelems(aa);
        M := Dimension(J);
        # Shared scratch buffer, sized to the next power of two; entries past
        # N are never read.
        cum := sampfin_alloc(2^ceil(log[2](N)));
        cum[1] := aa[1];
        for i from 2 to N do
            cum[i] := cum[i-1]+aa[i];
        end do;
        total := cum[N];
        for i to M do
            # locate the cumulative-weight cell that contains the random point
            J[i] := binsearch_c(randf(0,total),cum,N);
        end do;
        return J;
    end proc;

    # Uniform samples from an interval or from a box.
    #   rng - a single interval (a range a..b, or something whose first entry
    #         is numeric, e.g. [a,b]) or a list of such intervals
    #   N   - optional: a count, or an existing container to fill in place
    # Single interval: returns one float, or a filled vector when N is given.
    # List of intervals: returns a list with one float per interval, or a
    # filled matrix (one row per sample) when N is given.
    # vecf and matf are defined outside this file -- presumably they allocate
    # a float vector / matrix of the requested size; TODO confirm.
    randrange := proc(rng,N)
    local out,i,j,m,r;
        if(type(rng,`..`) or type(rng[1],'numeric')) then
            if(nargs<>2) then
                return randf(op(rng));
            end if;
            out := `if`(type(N,'numeric'),vecf(N),args[2]);
            for i to Dimension(out) do
                out[i] := randf(op(rng));
            end do;
            return out;
        end if;
        m := nops(rng);
        if(nargs<>2) then
            return [seq(randf(op(r)),r=rng)];
        end if;
        out := `if`(type(N,'numeric'),matf(N,m),args[2]);
        for i to Dimension(out)[1] do
            for j to m do
                out[i,j] := randf(op(rng[j]));
            end do;
        end do;
        return out;
    end proc;

    # Return a uniformly chosen row of the Matrix A.
    randrow := proc(A)
    local nrows,ncols;
        nrows,ncols := Dimension(A);
        A[(rand() mod nrows)+1];
    end proc;

    # Row sums of the Matrix C, returned as a float[8] column Vector.
    rowsum := proc(C)
    local nr,nc,i,j;
        nr,nc := Dimension(C);
        Vector([seq(add(C[i,j],j=1..nc),i=1..nr)],datatype=float[8]);
    end proc;

    # Column sums of the Matrix C, returned as a float[8] row Vector.
    colsum := proc(C)
    local nr,nc,i,j;
        nr,nc := Dimension(C);
        Vector[row]([seq(add(C[i,j],i=1..nr),j=1..nc)],datatype=float[8]);
    end proc;

    # Divide every row of C by its row sum.
    # Returns the scaled Matrix; when called as rownorm(C,true) the Vector of
    # row sums is returned as a second value.
    rownorm := proc(C)
    local sums,scaled;
        sums := rowsum(C);
        scaled := DiagonalMatrix(sums,datatype=float[8])^(-1).C;
        if(nargs<>2 or args[2]<>true) then
            return scaled;
        end if;
        return scaled,sums;
    end proc;

    # Scale every row of A to Euclidean length 1.  A is modified in place and
    # also returned.
    normrows := proc(A)
    local nrows,ncols,row,col,len;
        nrows,ncols := Dimension(A);
        for row to nrows do
            len := sqrt(add(A[row,col]^2,col=1..ncols));
            for col to ncols do
                A[row,col] := A[row,col]/len;
            end do;
        end do;
        return A;
    end proc;

    # Divide every column of C by its column sum.
    # Returns the scaled Matrix; when called as colnorm(C,true) the Vector of
    # column sums is returned as a second value.
    colnorm := proc(C)
    local sums,scaled;
        sums := colsum(C);
        scaled := C.DiagonalMatrix(sums,datatype=float[8])^(-1);
        if(nargs<>2 or args[2]<>true) then
            return scaled;
        end if;
        return scaled,sums;
    end proc;

    # Point on the unit circle at angle t, as the list [cos(t),sin(t)] of
    # floats.  With no argument the angle is drawn with randf(0,2*Pi).
    cis := proc(t)
        if(nargs>0) then
            return [evalf(cos(t)),evalf(sin(t))];
        end if;
        procname(randf(0,2*Pi));
    end proc;

    # Random m-element subset of {1,...,n}, returned as a list in set order.
    # Raises an error when m>n.
    #
    # Each pass draws (with replacement) as many values as are still missing
    # from the pool of not-yet-chosen numbers; duplicates within a pass are
    # absorbed by the set and made up for on the next pass.
    rsubset := proc(m,n)
    local pool,chosen,remaining,cnt,i;
        if(m>n) then
            error "rsubset: cannot draw %1 distinct values from 1..%2",m,n;
        end if;
        pool := {seq(i,i=1..n)};
        chosen := {};
        while(nops(chosen)<m) do
            remaining := pool minus chosen;
            cnt := nops(remaining);
            # Index into `remaining`: the random number is a position in the
            # pool, not a value.  Inserting the raw position instead can pick
            # already-chosen values for ever (e.g. m=n).
            chosen := chosen union {seq(remaining[rand() mod cnt+1],i=1..m-nops(chosen))};
        end do;
        return [op(chosen)];
    end proc;

    # Uniform random integer.
    #   randi(n)    - integer in 1..n
    #   randi(a..b) - integer in a..b
    randi := proc(n)
    local lo,hi;
        if(not type(args[1],`numeric`)) then
            lo,hi := op(args[1]);
            return lo+(rand() mod (hi-lo+1));
        end if;
        randi(1..n);
    end proc;

    # Draw indices from 1..numelems(pl) with probability proportional to pl.
    #   pl - nonnegative weights (need not sum to 1)
    #   n  - number of draws; when omitted a single index is returned instead
    #        of a list
    # Returns a list of n indices.  Prints the running count every 50 draws.
    # Raises "negative probability" if any weight is negative, and an error
    # for an empty weight list or a zero total.
    sample := proc(pl,n)
    local ans,i,j,N,rl,p,r,pick;
        if(nargs=1) then
            return sample(args[1],1)[1];
        end if;
        N := numelems(pl);
        if(N=0) then
            error "empty probability list";
        end if;
        # validate every weight, including the first one
        for i from 1 to N do
            if(pl[i]<0) then
                error "negative probability";
            end if;
        end do;
        # rl[i] = pl[1]+...+pl[i]
        rl := Vector(N);
        rl[1] := pl[1];
        for i from 2 to N do
            rl[i] := rl[i-1]+pl[i];
        end do;
        p := rl[N];
        if(p<=0) then
            error "probabilities sum to zero";
        end if;
        ans := Array(1..n);
        for i from 1 to n do
            r := randf(0,p);
            # First cell whose cumulative weight exceeds r.  Defaulting to N
            # guarantees a result even if r rounds up to exactly p.
            pick := N;
            for j from 1 to N-1 do
                if(r<rl[j]) then
                    pick := j;
                    break;
                end if;
            end do;
            ans[i] := pick;
            if(i mod 50=0) then
                print(i);
            end if;
        end do;
        return [seq(ans[i],i=1..n)];
    end proc;

    # Build a sampler object for a distribution on [a,b] whose (unnormalised)
    # density is the function kf.
    #   rv(kf,a..b) or rv(kf,a,b)
    # The returned object can be applied: obj() gives one sample and obj(k) a
    # list of k samples.  Sampling inverts the CDF numerically with fsolve.
    # NOTE(review): with a single argument a and b are never assigned.
    # NOTE(review): md, s, ans, x and i are not declared and become implicit
    # locals of their enclosing procedures.
    rv := proc(kf,r)
        md := module()
        option object;
        export kf,a,b,samp,cdf;
        local y,init,ModuleApply,ModulePrint;
            # Set-up: turn the end points into floats and build the normalised
            # CDF as an expression in the module-local name y.
            # `option remember` on a zero-argument procedure -- presumably so
            # the integrals are only computed on the first call; TODO confirm.
            init::static := proc()
            option remember;
            local x;
                a,b := evalf(a),evalf(b);
                cdf := int(kf(x),x=a..y)/int(kf(x),x=a..b);
                protect('kf','a','b','cdf');
            end proc;
            # obj(k) -> list of k samples; any other call -> one sample.
            ModuleApply::static := proc()
                if(nargs=1) then
                    return [seq(samp(),i=1..args[1])];
                else
                    return samp();
                end if;
            end proc;
            # Display form of the object.
            ModulePrint := proc()
                s := "random variable";
                return nprintf(s);
            end proc;
            # One sample: draw z with randf(0,1), solve cdf=z and return the
            # first root that lies in [a,b]; raises "no sample" if none does.
            samp::static := proc()
            local z;
                init();
                z := randf(0,1);
                ans := [fsolve(cdf=z)];
                for x in ans do
                    if(a<=x and x<=b) then
                        return x;
                    end if;
                end do;
                error "no sample";
            end proc;
        end module;
        # Fill in the exports from the arguments of rv.
        md:-kf := kf;
        if(nargs=2) then
            md:-a,md:-b := op(r);
        elif(nargs=3) then
            md:-a,md:-b := args[2..3];
        end if;
        return md;
    end proc;

    # Sample from a normal distribution truncated to the range r = a..b.
    #   tnorm(a..b)         - standard normal
    #   tnorm[mu,sig](a..b) - normal with mean mu and standard deviation sig
    # Raises "empty range" when b<a.
    tnorm := proc(r)
    local a,b,mu,sig;
        a,b := op(r);
        if(b<a) then
            error "empty range";
        end if;
        if(type(procname,indexed)) then
            # Standardise first and let the plain form do the tail trimming:
            # the +-10 cut-off below only makes sense in standard units.
            mu,sig := op(procname);
            return mu+sig*tnorm((a-mu)/sig..(b-mu)/sig);
        end if;
        # Trim each tail independently to +-10 standard deviations, but never
        # past the other end point, so the range cannot become inverted.
        # A range lying entirely beyond +-10 collapses to its inner end point.
        if(a<-10) then
            a := min(-10.0,b);
        end if;
        if(b>10) then
            b := max(10.0,a);
        end if;
        return tnorm0(a,b);
    end proc;

    # Standard normal truncated to [a,b]: reduce the problem to a range with
    # 0 <= a <= b, which is what tnorm1 handles.
    #   a<0<b  : pick the negative or positive half with probability
    #            proportional to its normal mass, then recurse on that half
    #   a<0>=b : reflect through the origin
    tnorm0 := proc(a,b)
    local left,right;
        if(a<0 and b>0) then
            left := erfn(-a)-.5;    # mass of [a,0]
            right := erfn(b)-.5;    # mass of [0,b]
            if(randf(0,1)<left/(left+right)) then
                return -tnorm0(0,-a);
            else
                return tnorm0(0,b);
            end if;
        elif(a<0) then
            # Here b<=0.  b=0 must be reflected too: tnorm1's envelope is
            # only valid for a nonnegative lower end point.
            return -tnorm0(-b,-a);
        end if;
        return tnorm1(a,b);
    end proc;

    # Rejection sampler for the standard normal on [a,b]: propose x with
    # randf(a,b) and accept it when a height drawn with randf(0,expn(a)) falls
    # under the density expn(x).  expn(a) bounds the density on [a,b] only for
    # 0 <= a <= b, which is what tnorm0 arranges before calling.
    tnorm1 := proc(a,b)
    local peak,x,y;
        peak := expn(a);
        do
            x,y := randf(a,b),randf(0,peak);
            if(y<=expn(x)) then
                return x;
            end if;
        end do;
    end proc;

    # Standard normal density exp(-x^2/2)/sqrt(2*Pi), evaluated in floats.
    expn := proc(x)
    local xf;
        xf := evalf(x);
        exp(-xf^2/2)/rt2pi;
    end proc;

    # Float value of sqrt(2*Pi), the normalising constant used by expn.
    rt2pi := evalf(sqrt(2*Pi));

    # Standard normal CDF, expressed through erf.
    erfn := proc(x)
    local u;
        u := x/sqrt2;
        (1+erf(u))/2;
    end proc;

    # Float value of sqrt(2), used by erfn.
    sqrt2 := evalf(sqrt(2));

    # Skeleton of a Gaussian sampler object (unfinished: samp has an empty
    # body, and p, setU1, rl and A are exported but never assigned here).
    # Not in the module's export list.
    gaussian := proc(dim)
        # The arguments are captured in a list because inside the nested
        # module the name dim refers to the module's own export.
        argl := [args];
        md := module()
        option object;
        export samp,p,dim,setU1,rl,A;
            dim := argl[1];
            # placeholder -- no sampling is implemented yet
            samp::static := proc()

            end proc;
        end module;
        return md;
    end proc;

    # Kernel-object constructor, dispatched on the index of the call:
    #   dker['correlation'](...) - forwards its arguments to dker1
    #   dker(...)                - same as dker['gaussian'](...)
    # Only 'correlation' is implemented, so every other index -- including
    # the 'gaussian' default used by the un-indexed call -- raises an error.
    dker := proc()
    local dtype;
        if(not type(procname,indexed)) then
            return dker['gaussian'](args);
        end if;
        dtype := op(procname);
        if(dtype='correlation') then
            return dker1(args);
        end if;
        error "dker: unsupported kernel type %1 (only 'correlation' is implemented)",dtype;
    end proc;

    # Object behind dker['correlation']: holds a reference item p and a
    # bandwidth-like parameter h, and produces perturbed copies q of p.
    # regularize, dframe and dplot are not defined in this file; what they do
    # is assumed from their use here -- TODO confirm against their definitions.
    # NOTE(review): the parameters ds and h are only read through args (argl).
    dker1 := proc(ds,h)
        argl := [args];
        md := module()
        option object;
        export p,q,n,h,samp,draw;
            # unpack the two constructor arguments into the exports p and h
            p,h := op(argl);
            n := p:-n;
            p := regularize['meannorm'](p);
            # q starts as a copy of p and is overwritten by every samp() call
            q := Object(p);
            samp::static := proc()
                # start from a random unit direction
                q[..] := rsphere(n);
                c := p.q/n;
                # flip q so that c is nonnegative
                if(c<0) then
                    for i from 1 to n do
                        q:-vect[i] := -q:-vect[i];
                    end do;
                    c := -c;
                end if;
                # add c*(1/h-1) times p to q
                for i from 1 to n do
                    q:-vect[i] := q:-vect[i]+c*p:-vect[i]*(1/h-1);
                end do;
                regularize['meannorm'](q,true);
                return q;
            end proc;
            # Draw a fresh sample and plot p and q as the two columns of a frame.
            draw::static := proc()
                samp();
                A := dframe(n,2);
                A[..,1] := p;
                A[..,2] := q;
                return dplot(A);
            end proc;
        end module;
        return md;
    end proc;

    # Random point on the unit sphere in R^n: draw n independent standard
    # normal values and divide by their Euclidean length (the result has norm
    # 1, so it lies on the sphere rather than inside the ball).
    rsphere := proc(n)
    local pt,len,i;
    uses Statistics;
        pt := Sample(rsphere0(),n);
        len := sqrt(add(pt[i]^2,i=1..n));
        for i to n do
            pt[i] := pt[i]/len;
        end do;
        return pt;
    end proc;

    # Standard normal random variable used by rsphere; `option remember`
    # makes every call return the one variable created on the first call.
    rsphere0 := proc()
    option remember;
    uses Statistics;
        RandomVariable(Normal(0,1));
    end proc;

    # Column means and (population, divide-by-N) standard deviations of the
    # N x M array A.  The mean of column j is written to V[j] and its standard
    # deviation to rl[j].
    meanstd0 := proc(A::Array(datatype=float[8]),V::Array(datatype=float[8]),rl::Array(datatype=float[8]),N::integer[4],M::integer[4])
    local row::integer[4],col::integer[4],mu::float[8],ss::float[8];
        for col from 1 to M do
            # first pass: mean
            mu := 0.0;
            for row from 1 to N do
                mu := mu+A[row,col];
            end do;
            mu := mu/N;
            V[col] := mu;
            # second pass: squared deviations from the mean
            ss := 0.0;
            for row from 1 to N do
                ss := ss+(A[row,col]-mu)^2;
            end do;
            rl[col] := sqrt(ss/N);
        end do;
    end proc;

    # Replace the interpreted procedure by its natively compiled version.
    meanstd0 := Compiler:-Compile(meanstd0);

    # Weighted column means and standard deviations of the N x M array A,
    # with row i weighted by rho[i].  Results are written to V (means) and rl
    # (standard deviations); both are normalised by the sum of the weights.
    meanstd1 := proc(A::Array(datatype=float[8]),rho::Array(datatype=float[8]),V::Array(datatype=float[8]),rl::Array(datatype=float[8]),N::integer[4],M::integer[4])
    local row::integer[4],col::integer[4],wsum::float[8],mu::float[8],ss::float[8];
        # total weight
        wsum := 0.0;
        for row from 1 to N do
            wsum := wsum+rho[row];
        end do;
        for col from 1 to M do
            mu := 0.0;
            for row from 1 to N do
                mu := mu+A[row,col]*rho[row];
            end do;
            mu := mu/wsum;
            V[col] := mu;
            ss := 0.0;
            for row from 1 to N do
                ss := ss+(A[row,col]-mu)^2*rho[row];
            end do;
            rl[col] := sqrt(ss/wsum);
        end do;
    end proc;

    # Replace the interpreted procedure by its natively compiled version.
    meanstd1 := Compiler:-Compile(meanstd1);

    # Mean and (population) standard deviation.
    #   A   - a float[8] Vector, or a float[8] Matrix whose columns are the
    #         variables and whose rows are the observations
    #   rho - optional weights, one per entry / row
    # Vector input: returns the two numbers mean,std.
    # Matrix input: returns two Vectors with the column means and column
    # standard deviations.  Called as meanstd[V,rl](...), the supplied V and rl
    # are used as output buffers instead of allocating new ones.
    # Any other input raises an error.
    meanstd := proc(A,rho)
    local N,M,c,r,d,V,rl,i;
        if(type(args[1],'Vector'(datatype=float[8]))) then
            N := Dimension(A);
            if(nargs=1) then
                c := convert(A,`+`)/N;
                r := sqrt(add((A[i]-c)^2,i=1..N)/N);
                return c,r;
            else
                d := convert(rho,`+`);
                c := DotProduct(A,rho)/d;
                r := sqrt(add((A[i]-c)^2*rho[i],i=1..N)/d);
                return c,r;
            end if;
        elif(type(args[1],'Matrix'(datatype=float[8]))) then
            N,M := Dimension(A);
            if(type(procname,indexed)) then
                V,rl := op(procname);
            else
                V := Vector(M,datatype=float[8]);
                rl := Vector(M,datatype=float[8]);
            end if;
            # the compiled helpers fill V and rl in place
            if(nargs=1) then
                meanstd0(A,V,rl,N,M);
            else
                meanstd1(A,rho,V,rl,N,M);
            end if;
            return V,rl;
        else
            error "meanstd expects a float[8] Vector or Matrix, but received %1",args[1];
        end if;
    end proc;

    # Subtract V[j] from every entry of column j of the N x M array A, in place.
    normdata0 := proc(A::Array(datatype=float[8]),V::Array(datatype=float[8]),N::integer[4],M::integer[4])
    local row::integer[4],col::integer[4],shift::float[8];
        for col from 1 to M do
            shift := V[col];
            for row from 1 to N do
                A[row,col] := A[row,col]-shift;
            end do;
        end do;
    end proc;

    # Replace the interpreted procedure by its natively compiled version.
    normdata0 := Compiler:-Compile(normdata0);

    # Divide every entry of column j of the N x M array A by rl[j], in place.
    normdata1 := proc(A::Array(datatype=float[8]),rl::Array(datatype=float[8]),N::integer[4],M::integer[4])
    local row::integer[4],col::integer[4],scale::float[8];
        for col from 1 to M do
            scale := rl[col];
            for row from 1 to N do
                A[row,col] := A[row,col]/scale;
            end do;
        end do;
    end proc;

    # Replace the interpreted procedure by its natively compiled version.
    normdata1 := Compiler:-Compile(normdata1);

    # Centre (and optionally scale) the columns of A.
    #   A  - data Matrix, rows are observations
    #   V  - per-column offsets, subtracted from the columns
    #   rl - optional per-column scales; when given, columns are divided by
    #        them after centring
    # The index selects where the result goes:
    #   normdata(...) / normdata[false](...) - a fresh float[8] copy of A
    #   normdata[true](...)                  - A itself is modified in place
    #   normdata[B](...)                     - the float[8] Matrix B is used
    # NOTE(review): in the normdata[B] form A is not copied into B first, so
    # the operation is applied to B's existing contents -- confirm intended.
    # Any other index raises an error.
    normdata := proc(A,V,rl)
    local N,M,A1;
        N,M := Dimension(A);
        if(not type(procname,indexed) or op(1,procname)=false) then
            A1 := Matrix(N,M,datatype=float[8]);
            ArrayTools:-Copy(A,A1);
        elif(op(1,procname)=true) then
            A1 := A;
        elif(type(op(procname),'Matrix'(datatype=float[8]))) then
            A1 := op(procname);
        else
            # without this branch A1 would stay unassigned and the compiled
            # helper below would fail with an unrelated message
            error "normdata index must be true, false or a float[8] Matrix, but received %1",op(procname);
        end if;
        normdata0(A1,V,N,M);
        if(nargs=3) then
            normdata1(A1,rl,N,M);
        end if;
        return A1;
    end proc;

    # Standardise the columns of A: subtract the (optionally rho-weighted)
    # column means and divide by the column standard deviations, as computed
    # by meanstd.  Returns the result of normdata.
    meannorm := proc(A,rho)
        # meanstd returns the pair mean,std, which becomes arguments 2 and 3
        normdata(A,meanstd(args));
    end proc;

    # Centre the columns of A: subtract the (optionally rho-weighted) column
    # means computed by meanstd.  No scaling is applied.
    meancent := proc(A,rho)
    local mu,sd;
        mu,sd := meanstd(args);
        normdata(A,mu);
    end proc;

    # Project data onto the principal axes of B.
    #   Al  - a Matrix, or a list of Matrices, to project
    #   B   - Matrix from which the axes are computed (rows are observations)
    #   rho - optional row weights for B
    # B is centred, its rows are scaled by sqrt(rho[i]) when weights are given,
    # and the right singular vectors of the result are used as the new basis.
    # Returns the list of projected Matrices, or a single Matrix when Al was
    # not a list.
    mpca := proc(Al,B,rho)
    local mu,sd,centred,nrows,ncols,i,axes,item;
        if(not type(args[1],'list')) then
            return mpca([args[1]],args[2..nargs])[1];
        end if;
        mu,sd := meanstd(args[2..nargs]);
        centred := normdata(B,mu);
        nrows,ncols := Dimension(B);
        if(nargs=3) then
            for i to nrows do
                centred[i,..] := centred[i,..]*sqrt(rho[i]);
            end do;
        end if;
        # third value returned by svd is Vt; its transpose holds the axes
        axes := Transpose(svd(centred)[3]);
        [seq(item.axes,item in Al)];
    end proc;

    # First M columns of the `principalcomponents` field of PCA(A).
    getpca := proc(A,M)
    local fit;
        fit := PCA(A);
        fit:-principalcomponents[..,1..M];
    end proc;

    # Reduced singular value decomposition A = U1 . D1 . Vt1.
    # Singular values below tol = 1e-7 are dropped from the tail, so the
    # factors have d columns / rows, d being the number of values kept.
    # Returns U1 (N x d), D1 (d x d, diagonal storage) and Vt1 (d x M).
    # Raises an error when no singular value reaches the tolerance.
    svd := proc(A)
    local tol,N,M,d,U1,S1,D1,Vt1;
        tol := .0000001;
        N,M := Dimension(A);
        d := min(N,M);
        # Ask LAPACK only for the smaller of the two orthogonal factors; the
        # other one is reconstructed from it below.
        if(N<=M) then
            U1,S1 := SingularValues(A,output=['U','S']);
        else
            Vt1,S1 := SingularValues(A,output=['Vt','S']);
        end if;
        # The d>0 guard stops the scan before it indexes S1[0].
        while(d>0 and S1[d]<tol) do
            d := d-1;
        end do;
        if(d=0) then
            error "svd: matrix has no singular value above the tolerance %1",tol;
        end if;
        S1 := S1[1..d];
        D1 := DiagonalMatrix(S1,datatype=float[8],storage=diagonal);
        if(N<=M) then
            U1 := U1[1..N,1..d];
            Vt1 := D1^(-1).Transpose(U1).A;
        else
            Vt1 := Vt1[1..d,1..M];
            U1 := A.Transpose(Vt1).D1^(-1);
        end if;
        return U1,D1,Vt1;
    end proc;

    # Project A onto the d1 leading singular directions of C = A^T.A.
    # Prints the d1 leading singular values of C as a list and returns the
    # projected data A.U1 (one column per direction).
    getsvd := proc(A,d1)
    local C,U1,S1;
        C := Transpose(A).A;
        # a single decomposition provides both the vectors and the values
        U1,S1 := SingularValues(C,output=['U','S']);
        print(convert(S1[1..d1],'list'));
        return A.U1[..,1..d1];
    end proc;

    # Fill a Matrix with normal samples, one sample per row.
    #   V - centre (a Vector), or a number n meaning the zero Vector of length n
    #   h - standard deviation passed to Normal(0,h); a Vector or Matrix h is
    #       instead moved into the index (see below) and 1.0 is used
    #   N - number of samples, or an existing Matrix to fill in place
    # Called as normsamp[H](...), the raw samples are multiplied by the
    # inverse of a square-root factor of H (entrywise sqrt for a diagonal H,
    # Cholesky factor otherwise) before V is added to every row.
    # NOTE(review): the third argument is read unconditionally via args[3].
    normsamp := proc(V,h,N)
    uses Statistics;
        if(type(V,'numeric')) then
            return normsamp(Vector(args[1],datatype=float[8]),h,N);
        end if;
        n := Dimension(V);
        if(type(args[3],'Matrix')) then
            B := args[3];
        else
            # allocate the output and re-enter with a buffer
            B := Matrix(N,n,datatype=float[8]);
            procname(V,h,B);
            return B;
        end if;
        if(type(h,'Vector') or type(h,'Matrix')) then
            # here N is already the output Matrix
            return normsamp[h](V,1.0,N);
        end if;
        Y := RandomVariable(Normal(0,h));
        # fills B in place
        Sample(Y,B);
        if(type(procname,indexed) and nops(procname)=1) then
            H := op(procname);
            if(type(H,'Matrix'(shape=diagonal))) then
                L := map(sqrt,H);
            else
                L := LUDecomposition(H,method='Cholesky');
            end if;
            Multiply(B,L^(-1),'inplace');
        end if;
        # normdata subtracts its second argument, so passing -V adds V
        normdata[true](B,-V);
        return B;
    end proc;

    # Unconstrained worker for sampgauss.
    #   H  - symmetric matrix of the quadratic form
    #   h  - standard deviation of the raw normal samples
    #   bb - linear term
    #   N  - number of samples, or an existing Matrix to fill in place
    # Raw Normal(0,h) samples are multiplied by the transpose of
    # U.diag(1/sqrt(S)) (U,S from the SVD of H) and then shifted row-wise by
    # -V, where V = H^(-1).bb.
    # allocla, rmult and shiftrows are defined outside this file --
    # presumably allocation, in-place right multiplication and in-place row
    # shift; TODO confirm.
    sampgauss0 := proc(H,h,bb,N)
    local x,m,U1,S1,D1,A1,B,V;
        m := Dimension(H)[1];
        U1,S1 := SingularValues(H,output=['U','S']);
        D1 := DiagonalMatrix(S1,datatype=float[8],storage=diagonal);
        A1 := U1.DiagonalMatrix(map(x->1/sqrt(x),S1),datatype=float[8]);
        # The count-or-buffer argument is N.  Testing H (always a Matrix)
        # would make the allocation branch unreachable.
        if(type(N,'numeric')) then
            B := allocla[float[8]]([N,m]);
        else
            B := N;
        end if;
        # fills B in place
        Sample(Normal(0,h),B);
        rmult(B,Transpose(A1));
        V := U1.D1^(-1).Transpose(U1).bb; #just H^(-1).bb
        shiftrows(B,-V);
        return B;
    end proc;

    # Constrained worker for sampgauss: sample in the coordinates returned by
    # linsolve(A,V), then map the samples back and shift them.
    # linsolve is expected to return two values (named base and basis here) --
    # presumably a particular solution and a basis of the solution space;
    # TODO confirm against its definition.  shiftrows is defined outside this
    # file.
    sampgauss1 := proc(H,h,bb,A,V,N)
    local base,basis,basisT,draws;
        base,basis := linsolve(A,V);
        basisT := Transpose(basis);
        # reduced quadratic form and linear term, then back to full coordinates
        draws := sampgauss0(basisT.H.basis,h,basisT.bb+basisT.H.base,N).basisT;
        shiftrows(draws,base);
        return draws;
    end proc;

    # Sample from the Gaussian with density proportional to
    # exp(-(x^t.H.x/2+bb.x)), optionally restricted to the subspace given by A.
    #   H - a number n (the n x n identity is used), a Matrix (bb is taken
    #       from allocla[float[8]](m)), or a list [H,bb]
    #   A - optional constraint: a Matrix, or a list [A,V]
    #   N - sample count (or buffer) handed on to the worker procedures
    # sampgauss[h](...) sets the standard deviation of the raw samples; the
    # default is 1.0.
    # With two arguments the second one is treated as N and the
    # unconstrained worker sampgauss0 is used.
    # allocla is defined outside this file -- presumably it returns a zero
    # float[8] vector of the given length; TODO confirm.
    sampgauss := proc(H,A,N)
        # normalise the first argument to the list form [H,bb]
        if(type(args[1],'numeric')) then
            return procname(IdentityMatrix(args[1],datatype=float[8]),args[2..nargs]);
        elif(type(args[1],'Matrix')) then
            m := Dimension(H)[1];
            return procname([H,allocla[float[8]](m)],args[2..nargs]);
        end if;
        H1,bb1 := op(H);
        m := Dimension(H1)[1];
        if(type(procname,indexed)) then
            h := op(procname);
        else
            h := 1.0;
        end if;
        if(nargs=2) then
            return sampgauss0(H1,h,bb1,args[2]);
        elif(type(A,'Matrix')) then
            # normalise the constraint to the list form [A,V]
            return procname(H,[A,allocla[float[8]](m)],args[3..nargs]);
        end if;
        A1,V1 := op(A);
        return sampgauss1(H1,h,bb1,A1,V1,N);
    end proc;

    # Covariance matrix of the columns of A (rows are observations), using
    # the divide-by-N normalisation: centred^T . centred / N.
    covmat := proc(A)
    local nobs,nvars,centred;
        nobs,nvars := Dimension(A);
        centred := meancent(A);
        Transpose(centred).centred/nobs;
    end proc;

    # Draw n rows of A uniformly at random, with replacement, and return them
    # as the rows of a new matrix allocated with matf(n,m).
    # matf is defined outside this file -- presumably it allocates an n x m
    # float matrix; TODO confirm.
    rowsamp := proc(A,n)
    local nrows,ncols,out,dst,src,col;
        nrows,ncols := Dimension(A);
        out := matf(n,ncols);
        for dst to n do
            src := (rand() mod nrows)+1;
            for col to ncols do
                out[dst,col] := A[src,col];
            end do;
        end do;
        return out;
    end proc;

end module;

