DualAlpha := module()
option package;
export pcech,pdiag,nervepd,alphaplex,alpharadii,witmap,pdweight,pdrange,pdradii;

#count the edges of the graph which is the nerve of the corresponding
#weighted ball covering, i.e. the number of pairs of balls that
#intersect, so that the edge array can be allocated in advance.
    pcech0 := proc(A::Array(datatype=float[8]),pow::Array(datatype=float[8]),a1::float[8],N::integer[4],m::integer[4])
    #Count the pairs p<q among the first N rows of A (points in R^m) whose
    #balls of squared radius pow[.]+a1 intersect. Balls with negative
    #squared radius are treated as empty and never counted.
        cnt := 0;
        for p from 1 to N do
            rp := pow[p]+a1;
            if(rp<0.0) then
                next;
            end if;
            rp := sqrt(rp);
            for q from p+1 to N do
                rq := pow[q]+a1;
                if(rq<0.0) then
                    next;
                end if;
                rq := sqrt(rq);
                #the balls meet iff |x_p-x_q|^2 <= (r_p+r_q)^2
                bound := rp+rq;
                bound := bound*bound;
                dist2 := 0.0;
                for t from 1 to m do
                    del := A[p,t]-A[q,t];
                    dist2 := dist2+del*del;
                    #stop summing as soon as the bound is exceeded
                    if(dist2>bound) then
                        break;
                    end if;
                end do;
                if(dist2<=bound) then
                    cnt := cnt+1;
                end if;
            end do;
        end do;
        return cnt;
    end proc;

    pcech0 := Compiler:-Compile(pcech0);

#compute the above intersection graph by storing every edge (i1,i2),
#i1<i2, as a row of E.
    pcech1 := proc(A::Array(datatype=float[8]),pow::Array(datatype=float[8]),a1::float[8],E::Array(datatype=integer[4]),N::integer[4],m::integer[4])
    #Write the edges of the intersection graph into E: row k of E receives
    #the k-th pair (i1,i2), i1<i2, of balls that meet, in the same
    #enumeration order as pcech0. E must have at least as many rows as
    #pcech0 reports for the same input. Balls with negative squared
    #radius pow[.]+a1 are treated as empty.
        N1 := 0;
        for i1 from 1 to N do
            s1 := pow[i1]+a1;
            if(s1<0.0) then
                next;
            end if;
            s1 := sqrt(s1);
            #i2 starts above i1, so no self-pair check is needed
            for i2 from i1+1 to N do
                s2 := pow[i2]+a1;
                if(s2<0.0) then
                    next;
                end if;
                s2 := sqrt(s2);
                #the balls meet iff the squared distance is <= (s1+s2)^2
                s := s1+s2;
                s := s*s;
                r := 0.0;
                for j from 1 to m do
                    c := A[i1,j]-A[i2,j];
                    r := r+c*c;
                    #stop summing as soon as the bound is exceeded
                    if(r>s) then
                        break;
                    end if;
                end do;
                if(r<=s) then
                    N1 := N1+1;
                    E[N1,1] := i1;
                    E[N1,2] := i2;
                end if;
            end do;
        end do;
    end proc;

    pcech1 := Compiler:-Compile(pcech1);

    pcech := proc(A,pow,a1)
    #Return the edge list of the intersection graph of the balls centred
    #at the rows of A with squared radii pow[.]+a1, one edge per row.
        npts,dim := Dimension(A);
        #first pass: count the edges so that E has exactly that many rows
        nedges := pcech0(A,pow,a1,npts,dim);
        E := allocla[integer[4]]([nedges,2]);
        #second pass: write the edges into E
        pcech1(A,pow,a1,E,npts,dim);
        return E;
    end proc;

    adjdata0 := proc(E::Array(datatype=integer[4]),E1::Array(datatype=integer[4]),J1::Array(datatype=integer[4]),J2::Array(datatype=integer[4]),N::integer[4],n::integer[4])
    #Turn the edge list E (N rows) on n vertices into adjacency data:
    #  E1 - 2*N directed edges (v,w), grouped by the first vertex v
    #  J1 - J1[v] is the first row of E1 belonging to v
    #  J2 - on exit, J2[v] is the number of rows of E1 belonging to v
    #J2 is incremented in place in the first loop, so the offsets are only
    #right if it is zero on entry.
        #pass 1: every edge adds one to the degree of both endpoints
        for e from 1 to N do
            u := E[e,1];
            v := E[e,2];
            J2[u] := J2[u]+1;
            J2[v] := J2[v]+1;
        end do;
        #running sums of the degrees give the start row of each vertex
        J1[1] := 1;
        for w from 2 to n do
            J1[w] := J1[w-1]+J2[w-1];
        end do;
        #reuse J2 as the fill counter of each vertex
        for w from 1 to n do
            J2[w] := 0;
        end do;
        #pass 2: edges seen from their second endpoint
        for e from 1 to N do
            v := E[e,2];
            row := J1[v]+J2[v];
            E1[row,1] := v;
            E1[row,2] := E[e,1];
            J2[v] := J2[v]+1;
        end do;
        #pass 3: edges seen from their first endpoint
        for e from 1 to N do
            u := E[e,1];
            row := J1[u]+J2[u];
            E1[row,1] := u;
            E1[row,2] := E[e,2];
            J2[u] := J2[u]+1;
        end do;
        return;
    end proc;

    adjdata0 := Compiler:-Compile(adjdata0);

    adjdata := proc(E,n)
    #Allocate and fill the adjacency data of the graph on n vertices with
    #edge list E. Returns E1 (directed edges grouped by vertex), J1 (start
    #row of each vertex in E1) and J2 (degree of each vertex).
    #NOTE(review): adjdata0 needs J2 to start at zero; this presumes
    #allocla returns zero-filled arrays - confirm.
        nedges := Dimension(E)[1];
        E1,J1,J2 := allocla[integer[4]]([2*nedges,2],n,n);
        adjdata0(E,E1,J1,J2,nedges,n);
        return E1,J1,J2;
    end proc;

    loadineqs := proc(i0::integer[4],S::Array(datatype=float[8]),pow::Array(datatype=float[8]),scale::float[8],E1::Array(datatype=integer[4]),J::Array(datatype=integer[4]),degs::Array(datatype=integer[4]),A::Array(datatype=float[8]),V::Array(datatype=float[8]),inds1::Array(datatype=integer[4]),inds2::Array(datatype=integer[4]),n::integer[4],d::integer[4],N::integer[4])
    #Load one row of (A,V) per neighbour of site i0, in coordinates
    #centred at S[i0] and divided by scale:
    #  A[k,.] = (S[i,.]-S[i0,.])/scale
    #  V[k]   = (|A[k,.]|^2+(pow[i0]-pow[i])/scale^2)/2
    #where i is the k-th neighbour of i0 in the adjacency data (E1,J,degs).
    #inds1[k] is set to that neighbour, and inds2 to the inverse lookup
    #(0 for sites that are not neighbours of i0). Returns the number of
    #rows written. The argument N is not used.
        inv := 1/scale;
        base := J[i0]-1;
        deg := degs[i0];
        #clear the inverse lookup before recording the neighbours
        for t from 1 to n do
            inds2[t] := 0;
        end do;
        for t from 1 to deg do
            nb := E1[base+t,2];
            inds1[t] := nb;
            inds2[nb] := t;
        end do;
        #fill row t of A and accumulate its right hand side in one sweep
        for t from 1 to deg do
            nb := inds1[t];
            acc := inv*inv*pow[i0]-inv*inv*pow[nb];
            for j from 1 to d do
                A[t,j] := inv*S[nb,j]-inv*S[i0,j];
                acc := acc+A[t,j]*A[t,j];
            end do;
            V[t] := acc/2;
        end do;
        return deg;
    end proc;

    loadineqs := Compiler:-Compile(loadineqs);

    pdiag0 := proc(i0::integer[4],S::Array(datatype=float[8]),X::Array(datatype=float[8]),wit::Array(datatype=float[8]),scale::float[8],d::integer[4])
    #Map the point X, given in the coordinates centred at site i0 and
    #divided by scale, back to the original coordinates; the result is
    #written into wit.
        for c from 1 to d do
            wit[c] := scale*X[c]+S[i0,c];
        end do;
    end proc;

    pdiag0 := Compiler:-Compile(pdiag0);

    pdiag := proc(S,pow,a1)
    #Build a power diagram object for the sites S (one per row), the
    #weights pow and the maximum weight a1. The object precomputes the
    #intersection graph of the balls of squared radius pow[i]+a1 and its
    #adjacency data, and then answers queries cell by cell: setcell(i)
    #loads the constraints of cell i into the solver ls, setface(sig)
    #restricts to a face and solves, and getweight/getwit/getactive read
    #the result back.
    #Raises "no nonempty cells" when every ball is empty.
        md := module()
        option object;
        export S,pow,a1,E,d,n,N,m,m1,i0,getadj,getcech,getdim,maxweight,numsites,numedges,getdeg,getdegs,maxdeg,J,degs,inds1,inds2,getcell,setcell,setface,getweight,getwit,getcoeffs,wit,A,V,E1,cc,init,ls,scale,`whattype`,getactive,getradii;
        local ModulePrint;
            #lets callers recognise the object with whattype(x)='PowDiag'
            `whattype`::static := proc()
                return 'PowDiag';
            end proc;
            ModulePrint::static := proc()
                return nprintf("power diagram, %d points in R^%d",n,d);
            end proc;
            #number of sites
            numsites::static := proc()
                return n;
            end proc;
            #the level a1 the diagram was built with
            maxweight::static := proc()
                return a1;
            end proc;
            #number of edges of the intersection graph
            numedges::static := proc()
                return N;
            end proc;
            #dimension of the ambient space
            getdim::static := proc()
                return d;
            end proc;
            #without arguments: the degree m of the current cell; with one
            #argument: the maximum degree m1.
            #NOTE(review): the argument i itself is never used - confirm
            #that degs[i] was not intended.
            getdeg::static := proc(i)
                if(nargs=1) then
                    return m1;
                else
                    return m;
                end if;
            end proc;
            #radius of every ball at level a1, empty balls reported as 0
            getradii::static := proc()
                return vecf([seq(sqrt(max(0.0,a1+pow[i])),i=1..n)]);
            end proc;
            #degree of every site in the intersection graph
            getdegs::static := proc()
                return degs;
            end proc;
            maxdeg::static := proc()
                return max(degs);
            end proc;
            #vertices (as an n by 1 matrix) and edges of the intersection graph
            getcech::static := proc()
            option remember;
                E0 := Matrix([seq([i],i=1..n)],datatype=integer[4]);
                return E0,E;
            end proc;
            #neighbours of site i in the intersection graph
            getadj::static := proc(i)
                k1,k2 := J[i],J[i]+degs[i]-1;
                return E1[k1..k2,2];
            end proc;
            #the site selected by the last setcell (0 before the first call)
            getcell::static := proc()
                return i0;
            end proc;
            #select cell i and load one constraint per neighbour into ls
            setcell::static := proc(i)
                i0 := i;
                m := loadineqs(i0,S,pow,scale,E1,J,degs,A,V,inds1,inds2,n,d,N);
                ls:-loadcoeffs(A,V,m);
                return;
            end proc;
            #restrict the current cell to the face shared with the sites in
            #sig and solve; returns false when some site of sig is not a
            #neighbour of the current cell or when the solver fails
            setface::static := proc(sig)
                l := nops(sig);
                #inds2 maps a site to its constraint row, 0 if not adjacent
                sig1 := [seq(inds2[i],i=sig)];
                objmax := (a1/scale^2+pow[i0]/scale^2)/2;
                ls:-seteqs(sig1);
                if(0 in sig1 or not ls:-solve(objmax)) then
                    return false;
                end if;
                return true;
            end proc;
            #sites whose constraints are active at the last solution
            getactive::static := proc()
                return [i0,seq(inds1[i],i=ls:-getactive())];
            end proc;
            #weight of the last solution, converted back from the scaled
            #coordinates used by the solver
            getweight::static := proc()
                return 2*ls:-r*scale^2-pow[i0];
            end proc;
            #last solution in the original coordinates; the returned vector
            #is the shared buffer wit, overwritten by the next call
            getwit::static := proc()
                X := ls:-getpoint();
                pdiag0(i0,S,X,wit,scale,d);
                return wit;
            end proc;
            getcoeffs::static := proc()
                return ls:-getcoeffs();
            end proc;
            init::static := proc()
                S,pow,a1 := args;
                c := a1+max(pow);
                if(c<0) then
                    error "no nonempty cells";
                end if;
                #the largest radius is used as the unit of length; when it
                #is zero fall back to 1.0, since loadineqs and setface
                #divide by scale
                if(c>0) then
                    scale := sqrt(c);
                else
                    scale := 1.0;
                end if;
                n,d := Dimension(S);
                E := pcech(S,pow,a1);
                N := Dimension(E)[1];
                E1,J,degs := adjdata(E,n);
                m1 := max(degs);
                #work arrays sized for the cell with the most neighbours
                A,V,wit,cc := allocla[float[8]]([m1,d],m1,d,n);
                inds1,inds2 := allocla[integer[4]](m1,n);
                i0 := 0;
                ls := lsquares(d,m1);
                m := infinity;
                return;
            end proc;
        end module;
        md:-init(args);
        return md;
    end proc;

    pdweight0 := proc(xx::Array(datatype=float[8]),S::Array(datatype=float[8]),pow::Array(datatype=float[8]),n::integer[4],d::integer[4])
    #Return the minimum over the n sites of |xx-S[i,.]|^2-pow[i], where xx
    #and the rows of S are points in R^d. With n=0 the result is infinity.
        best := Float(infinity);
        for site from 1 to n do
            val := -pow[site];
            for t from 1 to d do
                val := val+(xx[t]-S[site,t])^2;
            end do;
            best := min(best,val);
        end do;
        return best;
    end proc;

    pdweight0 := Compiler:-Compile(pdweight0);

#the weight function of the power diagram
    pdweight := proc(S,pow)
    #Return a callable object w with w(x) = min_i |x-S[i,.]|^2-pow[i].
    #A power diagram object may be passed instead of (S,pow); its S and
    #pow are then used.
        if(whattype(args[1])='PowDiag') then
            diagram := args[1];
            return pdweight(diagram:-S,diagram:-pow);
        end if;
        wfun := module()
        option object;
        export S,pow,n,d,init;
        local ModuleApply,ModulePrint,xx0;
            init::static := proc()
                S,pow := args;
                n,d := Dimension(S);
                #buffer handed to the compiled kernel on every evaluation
                xx0 := vecf(d);
            end proc;
            ModuleApply::static := proc(xx)
                #copy the point into the buffer before calling pdweight0
                for t from 1 to d do
                    xx0[t] := xx[t];
                end do;
                return pdweight0(xx0,S,pow,n,d);
            end proc;
            ModulePrint::static := proc()
                return nprintf("weight map R^%d->R",d);
            end proc;
        end module;
        wfun:-init(S,pow);
        return wfun;
    end proc;

#total range of all points in the union of the balls
    pdrange := proc(S,pow,a1)
    #Return the list [x0..x1, ...] with one range per coordinate, giving
    #the bounding box of the union of the balls centred at the rows of S
    #with squared radii pow[i]+a1. A power diagram object may be passed
    #instead of (S,pow,a1).
    #Balls with negative squared radius are empty and ignored; a ball of
    #radius exactly zero still contributes its centre, in agreement with
    #pcech0/pcech1. If every ball is empty, each range is
    #infinity..-infinity.
        if(whattype(args[1])='PowDiag') then
            pd := args[1];
            return pdrange(pd:-S,pd:-pow,pd:-a1);
        end if;
        n,d := Dimension(S);
        lo := Array(1..d,fill=Float(infinity));
        hi := Array(1..d,fill=-Float(infinity));
        for i from 1 to n do
            r2 := a1+pow[i];
            if(r2<0.0) then
                next;
            end if;
            #the radius is the same for every coordinate, compute it once
            r := sqrt(r2);
            for j from 1 to d do
                lo[j] := min(lo[j],S[i,j]-r);
                hi[j] := max(hi[j],S[i,j]+r);
            end do;
        end do;
        return [seq(lo[j]..hi[j],j=1..d)];
    end proc;

    nervepd0 := proc(pd,Sig)
    #Filter the candidate simplices given as the rows of Sig against the
    #power diagram pd. A row is kept when pd:-setface succeeds for it and
    #the resulting weight is at most pd:-maxweight(). Returns the kept
    #rows, their weights and their witness points (one per row).
    #The cell is only reloaded when the first vertex of a row changes.
        nrows,len := Dimension(Sig);
        n,d,a1 := pd:-numsites(),pd:-getdim(),pd:-maxweight();
        kept := 0;
        cur := 0;
        T := allocla[integer[4]]([nrows,len]);
        aa,W := allocla[float[8]](nrows,[nrows,d]);
        for row from 1 to nrows do
            sig := [seq(Sig[row,j],j=1..len)];
            if(sig[1]<>cur) then
                cur := sig[1];
                pd:-setcell(cur);
            end if;
            tprint[10]("%d/%d vertices, %d simplices...",cur,n,kept);
            if(pd:-setface(sig[2..len])) then
                a := pd:-getweight();
                wit := pd:-getwit();
                if(a<=a1) then
                    kept := kept+1;
                    for j from 1 to len do
                        T[kept,j] := sig[j];
                    end do;
                    aa[kept] := a;
                    setrowc(W,kept,wit,d);
                end if;
            end if;
        end do;
        #only the first kept rows of the work arrays are meaningful
        return T[1..kept],aa[1..kept],W[1..kept];
    end proc;

    nervepd1 := proc(pd,k1)
    #Build the nerve of the power diagram pd up to dimension k1. Returns
    #the filtered complex X and the witness map Phi; every accepted
    #simplex is added to X with its weight and to Phi with its witness
    #point. A candidate is accepted when pd:-setface succeeds for it and
    #its weight is at most pd:-maxweight().
        n,d,a1 := pd:-numsites(),pd:-getdim(),pd:-maxweight();
        X := fplex(n);
        Phi := witmap(X,d);
        #site whose cell is currently loaded; 0 matches no site, so the
        #first candidate always loads its cell (as in nervepd0)
        i := 0;
        for k from 0 to k1 do
            l := k+1;
            if(k<=1) then
                #vertices and edges come from the intersection graph
                E := pd:-getcech()[l];
            else
                #NOTE(review): presumably the candidate k-simplices derived
                #from the (k-1)-simplices already in X - confirm against
                #getlazy/getmat
                E := getlazy(X:-getmat(k-1));
            end if;
            N1 := Dimension(E)[1];
            N := 0;
            for i1 from 1 to N1 do
                sig := [seq(E[i1,j],j=1..l)];
                #reload the cell only when the first vertex changes
                if(sig[1]<>i) then
                    i := sig[1];
                    pd:-setcell(i);
                end if;
                tprint[10]("%d/%d vertices, %d simplices...",i,n,N);
                if(not pd:-setface(sig[2..l])) then
                    next;
                end if;
                a := pd:-getweight();
                wit := pd:-getwit();
                if(a<=a1) then
                    N := N+1;
                    j := X:-addfilt(sig,a);
                    Phi:-setelt(j,wit);
                end if;
            end do;
        end do;
        return X,Phi;
    end proc;

    nervepd := proc(pd)
    #Nerve of the power diagram pd. The second argument selects the mode:
    #  nervepd[true](pd)     - all dimensions up to pd:-d
    #  nervepd[true](pd,k1)  - numeric k1: all dimensions up to k1,
    #                          returns the complex and its witness map
    #  nervepd[true](pd,Sig) - Matrix Sig: filter the given simplices,
    #                          returns simplices, weights and witnesses
    #Called without an index, or with index false, only the first of
    #these return values is given back.
        if(not type(procname,indexed) or op(procname)=false) then
            return nervepd[true](args)[1];
        end if;
        if(nargs=1) then
            return procname(pd,pd:-d);
        elif(type(args[2],'Matrix')) then
            return nervepd0(args);
        elif(type(args[2],'numeric')) then
            return nervepd1(args);
        else
            error "nervepd: second argument must be a Matrix of simplices or a numeric dimension";
        end if;
    end proc;

    pdradii := proc(S,R)
    #Power diagram of the balls centred at the rows of S with radii R.
    #R is either one radius per site or a single number shared by all
    #sites. The weights are the squared radii and the level is 0.0.
        nsites,dim := Dimension(S);
        if(not type(R,'numeric')) then
            sqr := vecf([seq(R[k]^2,k=1..nsites)]);
            return pdiag(S,sqr,0.0);
        end if;
        #a single number: repeat it for every site and start over
        return procname(S,vecf([seq(R,k=1..nsites)]));
    end proc;

#compute the alpha complex
    alphaplex := module()
    #Callable object computing alpha complexes. It is applied either to
    #(S,pow,a1,k1), in which case a power diagram is built first, or to an
    #existing power diagram followed by the arguments of nervepd. An index,
    #as in alphaplex[true](...), is forwarded to nervepd; without an index
    #false is used.
    #Exported settings:
    #  primtol, singtol - copied into the solver ls of a newly built diagram
    #  rescale          - when false, the scale of a newly built diagram is
    #                     reset to 1.0
    option object;
    export primtol,singtol,getalpha,rescale,`?[]`;
    local ModuleApply;
        getalpha::static := proc(S,pow,a1,k1)
            #a call without an index is treated as index false
            if(not type(procname,indexed)) then
                return getalpha[false](args);
            end if;
            flag := op(procname);
            #a Matrix first argument is the site matrix: build the diagram,
            #apply the settings, then call again with the diagram
            if(type(args[1],'Matrix')) then
                pd := pdiag(S,pow,a1);
                pd:-ls:-primtol := primtol;
                pd:-ls:-singtol := singtol;
                if(not rescale) then
                    pd:-scale := 1.0;
                end if;
                return getalpha[flag](pd,k1);
            end if;
            #otherwise the arguments are passed to nervepd unchanged
            return nervepd[flag](args);
        end proc;
        #indexing the object, alphaplex[flag], yields getalpha[flag]
        `?[]`::static := proc()
            return getalpha[op(args[2])];
        end proc;
        ModuleApply::static := getalpha;
        #default settings
        primtol := .000001;
        singtol := .00000001;
        rescale := true;
    end module;

    alpharadii := proc(S,R,k1)
    #Alpha complex up to dimension k1 of the balls centred at the rows of
    #S with radii R (see pdradii for the accepted forms of R).
        diagram := pdradii(S,R);
        return alphaplex(diagram,k1);
    end proc;

#witness map from X to R^m, which sends each simplex to its unique
#minimal representative
    witmap := proc(X,m)
    #Return a witness map object for the complex X with values in R^m.
    #Values are stored by simplex index with setelt(i,V) and read by
    #simplex, either as a list with getelt(sig) (also available by calling
    #the object) or as a row of the storage with getvec(sig). Reading a
    #simplex whose value was never set raises "witness map not defined".
        md := module()
        option object;
        export X,d,A,setelt,getvec,getelt,init;
        local dyn,ModulePrint,ModuleApply,flags;
            ModulePrint::static := proc()
                return nprintf("witness map to R^%d",d);
            end proc;
            ModuleApply::static := proc(sig)
                return getelt(sig);
            end proc;
            #value at sig as a list of d coordinates
            getelt::static := proc(sig)
                i := X:-getind(sig);
                if(not flags[i]) then
                    error "witness map not defined";
                end if;
                return [seq(A[i,j],j=1..d)];
            end proc;
            #value at sig as row i of the storage array A
            getvec::static := proc(sig)
                i := X:-getind(sig);
                if(not flags[i]) then
                    error "witness map not defined";
                end if;
                #reuse the index found above instead of looking it up again
                return A[i];
            end proc;
            #store V as the value of the simplex with index i
            setelt::static := proc(i,V)
                #the storage may have been reallocated to make room for i;
                #refresh the references in that case
                if(dyn:-allocif(i)) then
                    A,flags := dyn:-getelts();
                end if;
                flags[i] := true;
                for j from 1 to d do
                    A[i,j] := V[j];
                end do;
                return;
            end proc;
            init::static := proc()
                X,d := args;
                #growable storage: one row of d floats and one flag per index
                dyn := dynla(float[8](d),boolean);
                A,flags := dyn:-getelts();
                return;
            end proc;
        end module;
        md:-init(X,m);
        return md;
    end proc;

end module;
