ImageMaps := module()
option package;
export map2arr,drawmap,getim,vec2arr,arr2vec,arrdim,allocarr,drawarr,colormap,colpar,map2im;
local colortab,colorvec,ModuleLoad;

    # ModuleLoad: installs the global type predicates `type/ColorMap` and
    # `type/ImageMap` and resets the global flag viewmode to false.
    # Each predicate is true exactly when whattype of its argument reports
    # the corresponding name.
    ModuleLoad := proc()
    global `type/ColorMap`,`type/ImageMap`,viewmode;
        # true when whattype(A) is the name ColorMap
        `type/ColorMap` := proc(A)
            return evalb(whattype(A)='ColorMap');
        end proc;
        # true when whattype(A) is the name ImageMap
        `type/ImageMap` := proc(A)
            return evalb(whattype(A)='ImageMap');
        end proc;
        viewmode := false;
    end proc;

    # arrdim: returns, as an expression sequence, the upper bound of every
    # index range reported by ArrayTools:-Dimensions(A).
    arrdim := proc(A)
    local bounds;
        bounds := ArrayTools:-Dimensions(A);
        return op(map2(op,2,bounds));
    end proc;

    # allocarr: allocates one Array per argument and returns them as an
    # expression sequence.  Each argument is either a list of dimensions
    # [m1,...,mk] (giving a 1..m1 x ... x 1..mk Array) or a single integer m
    # (giving a 1..m Array).  The element datatype is taken from the index,
    # e.g. allocarr[integer[4]](...); without an index it is float[8].
    allocarr := proc()
    local elemtype,dims,m;
        if(not type(procname,indexed)) then
            return allocarr[float[8]](args);
        end if;
        elemtype := op(procname);
        # Iterate over [args], not args: with exactly one list argument,
        # "for x in args" would walk the entries of that list and allocate a
        # separate 1-D Array per dimension instead of one multi-dimensional
        # Array.
        return seq(Array(seq(1..m,m=dims),datatype=elemtype),dims=[args]);
    end proc;

    # vec2arr0: copies the first m1*m2 entries of the flat vector V into the
    # m1 x m2 array C in row-major order (second index varies fastest).
    # The procedure is replaced by its compiled version right after it is
    # defined.
    vec2arr0 := proc(V::Array(datatype=float[8]),C::Array(datatype=float[8]),m1::integer[4],m2::integer[4])
    local row,col,base;
        for row from 1 to m1 do
            # offset of the entry just before row "row" in V
            base := (row-1)*m2;
            for col from 1 to m2 do
                C[row,col] := V[base+col];
            end do;
        end do;
    end proc;

    vec2arr0 := Compiler:-Compile(vec2arr0);

    # vec2arr1: copies the first m1*m2*m3 entries of the flat vector V into
    # the m1 x m2 x m3 array C in row-major order (last index varies
    # fastest).  The procedure is replaced by its compiled version right
    # after it is defined.
    vec2arr1 := proc(V::Array(datatype=float[8]),C::Array(datatype=float[8]),m1::integer[4],m2::integer[4],m3::integer[4])
    local p,q,s,slab,line;
        for p from 1 to m1 do
            # offset of the entry just before slab p
            slab := (p-1)*m2*m3;
            for q from 1 to m2 do
                # offset of the entry just before line (p,q)
                line := slab+(q-1)*m3;
                for s from 1 to m3 do
                    C[p,q,s] := V[line+s];
                end do;
            end do;
        end do;
    end proc;

    vec2arr1 := Compiler:-Compile(vec2arr1);

    # vec2arr: unpacks the flat vector V into a 2-D or 3-D array using the
    # compiled kernels vec2arr0 / vec2arr1 and returns that array.
    # Calling forms:
    #   vec2arr(V,A,ml)   fill the existing array A, dimensions given by ml
    #   vec2arr(V,A)      fill the existing Array/Matrix A, dimensions read from A
    #   vec2arr(V,ml)     allocate a fresh float[8] Array with dimensions ml
    #   vec2arr[ml](V,..) same as appending ml to the argument list
    # Raises an error when no target/dimension argument is given or when the
    # number of dimensions is not 2 or 3.
    vec2arr := proc(V)
    local A,ml,m;
        if(type(procname,indexed)) then
            return vec2arr(args,op(procname));
        end if;
        if(nargs<2) then
            error "expected a target array or a list of dimensions after the vector";
        elif(nargs=3) then
            A,ml := args[2..3];
        elif(type(args[2],'Array') or type(args[2],'Matrix')) then
            A := args[2];
            ml := [arrdim(A)];
        else
            ml := args[2];
            A := Array(seq(1..m,m=ml),datatype=float[8]);
        end if;
        if(nops(ml)=2) then
            vec2arr0(V,A,op(ml));
        elif(nops(ml)=3) then
            vec2arr1(V,A,op(ml));
        else
            error "only 2 or 3 dimensional arrays are supported, got %1 dimension(s)",nops(ml);
        end if;
        return A;
    end proc;

    # arr2vec0: flattens the m1 x m2 array C into the vector V in row-major
    # order (second index varies fastest).  The procedure is replaced by its
    # compiled version right after it is defined.
    arr2vec0 := proc(C::Array(datatype=float[8]),V::Array(datatype=float[8]),m1::integer[4],m2::integer[4])
    local row,col,base;
        for row from 1 to m1 do
            # offset of the entry just before row "row" in V
            base := (row-1)*m2;
            for col from 1 to m2 do
                V[base+col] := C[row,col];
            end do;
        end do;
    end proc;

    arr2vec0 := Compiler:-Compile(arr2vec0);

    # arr2vec1: flattens the m1 x m2 x m3 array C into the vector V in
    # row-major order (last index varies fastest).  The procedure is
    # replaced by its compiled version right after it is defined.
    arr2vec1 := proc(C::Array(datatype=float[8]),V::Array(datatype=float[8]),m1::integer[4],m2::integer[4],m3::integer[4])
    local p,q,s,slab,line;
        for p from 1 to m1 do
            # offset of the entry just before slab p
            slab := (p-1)*m2*m3;
            for q from 1 to m2 do
                # offset of the entry just before line (p,q)
                line := slab+(q-1)*m3;
                for s from 1 to m3 do
                    V[line+s] := C[p,q,s];
                end do;
            end do;
        end do;
    end proc;

    arr2vec1 := Compiler:-Compile(arr2vec1);

    # arr2vec: flattens the 2-D or 3-D array A into a vector using the
    # compiled kernels arr2vec0 / arr2vec1 and returns that vector.
    # Calling forms:
    #   arr2vec(A)         allocate the vector, dimensions read from A
    #   arr2vec(A,ml)      allocate the vector, dimensions given by the list ml
    #   arr2vec(A,V)       fill the existing float[8] Vector V
    #   arr2vec(A,V,ml)    fill V, dimensions given by ml
    #   arr2vec[ml](A,...) same as appending ml to the argument list
    # Fresh vectors come from allocla, which is not defined in this file.
    # Raises an error for any other argument pattern or when the number of
    # dimensions is not 2 or 3.
    arr2vec := proc(A)
    local V,ml;
        if(type(procname,indexed)) then
            return arr2vec(args,op(procname));
        end if;
        if(nargs=3) then
            V,ml := args[2..3];
        elif(nargs=2 and type(args[2],'list')) then
            ml := args[2];
            V := allocla(convert(ml,`*`));
        elif(nargs=2 and type(args[2],'Vector'(datatype=float[8]))) then
            V := args[2];
            ml := [arrdim(A)];
        elif(nargs=1) then
            ml := [arrdim(A)];
            V := allocla(convert(ml,`*`));
        else
            error "expected arr2vec(A), arr2vec(A,ml), arr2vec(A,V) or arr2vec(A,V,ml)";
        end if;
        if(nops(ml)=2) then
            arr2vec0(A,V,op(ml));
        elif(nops(ml)=3) then
            arr2vec1(A,V,op(ml));
        else
            error "only 2 or 3 dimensional arrays are supported, got %1 dimension(s)",nops(ml);
        end if;
        return V;
    end proc;

    # getball0: collects the linear indices of all grid cells whose centers
    # lie strictly within distance r of the point V, and returns how many
    # were found.
    #   V        point, d coordinates
    #   r        radius
    #   al,bl    lower / upper corner of the box, per dimension
    #   ml       number of cells per dimension
    #   epsl     cell width per dimension
    #   inds     output: 1-based row-major cell indices are written to inds[1..N1]
    #   d        number of dimensions
    #   tl,J0,J1 integer work arrays of length at least d (overwritten)
    # The caller must supply an inds array large enough for every hit.
    # The procedure is replaced by its compiled version right after it is
    # defined.
    getball0 := proc(V::Array(datatype=float[8]),r::float[8],al::Array(datatype=float[8]),bl::Array(datatype=float[8]),ml::Array(datatype=integer[4]),epsl::Array(datatype=float[8]),inds::Array(datatype=integer[4]),d::integer[4],tl::Array(datatype=integer[4]),J0::Array(datatype=integer[4]),J1::Array(datatype=integer[4]))
        # NOTE(review): this value is overwritten by the loop below when j=d;
        # the assignment looks redundant - confirm before removing.
        tl[d] := 1;
        # tl[j] = number of cells that r spans along axis j (rounded up);
        # N0 = number of cells in the surrounding box of (2*tl[j]+1) cells per axis.
        N0 := 1;
        for j from 1 to d do
            eps := epsl[j];
            t := ceil(r/eps);
            tl[j] := t;
            N0 := N0*(2*t+1);
        end do;
        r2 := r*r;
        N1 := 0;
        # J1 is a mixed-radix counter of offsets inside the box, starting at 0.
        for j from 1 to d do
            J1[j] := 0;
        end do;
        # J0[j] = 1-based index of the cell containing V along axis j.
        for j from 1 to d do
            J0[j] := ceil((V[j]-al[j])/(bl[j]-al[j])*ml[j]);
            end do;
        for i from 1 to N0 do
            # Propagate carries left over from the increment at the end of
            # the previous iteration, from the last axis towards the first.
            for j from d-1 to 1 by -1 do
                j1 := j+1;
                t := 2*tl[j1]+1;
                if(J1[j1]>=t) then
                    J1[j1] := J1[j1]-t;
                    J1[j] := J1[j]+1;
                end if;
            end do;
            flag := 1;
            k := 0;
            d2 := 0.0;
            for j from 1 to d do
                # candidate cell coordinate along axis j
                a := J0[j]+J1[j]-tl[j];
                if(a<1 or a>ml[j]) then
                    # outside the grid: reject this candidate
                    flag := 0;
                    break;
                end if;
                # accumulate the zero-based row-major index
                k := k*ml[j]+(a-1);
                # signed distance from V to the cell center along axis j
                c := al[j]+(a-.5)/ml[j]*(bl[j]-al[j])-V[j];
                d2 := d2+c*c;
            end do;
            # advance the counter along the last axis; carries are resolved
            # at the top of the next iteration
            J1[d] := J1[d]+1;
            if(flag=0 or d2>=r2) then
                next;
            end if;
            N1 := N1+1;
            inds[N1] := k+1;
        end do;
        return N1;
    end proc;

    getball0 := Compiler:-Compile(getball0);

#Object for storing an image as an array of samples over a box in space.
#It converts the data to and from vector form using the array-to-coordinates
#conversions, and associates a point in space (the cell center) with each
#pixel.
#
#getim(ml,rng,cmap) builds one such object.  The first argument is either a
#list of grid dimensions (a fresh float[8] Array is allocated) or an existing
#float[8] Array (used directly as the pixel storage).  rng is a list of
#ranges a..b, one per dimension (default -1..1 in every dimension); cmap is
#the color map handed to drawarr (default colormap('virdiv')).
    getim := proc(ml,rng,cmap)
    local argl,md;
        # The module below exports members named ml, rng and cmap, which hide
        # the parameters of the same name, so the actual arguments are
        # captured here and forwarded to init.
        argl := [args];
        md := module()
        option object;
        export arr,getvec,setvec,cmap,ml,getrng,al,bl,d,N,draw,init,getdata,getpoint,getcoords,getind,setarr,getcell,setmap,rng,epsl,getball,getball1,clear,`whattype`;
        local al1,bl1,ml1,epsl1,tl1,J0,J1,inds0,V1,ModuleApply,ModulePrint;
            # reports the name ImageMap, used by `type/ImageMap`
            `whattype`::static := proc()
                return 'ImageMap';
            end proc;
            # prints as e.g. "64x64 image"
            ModulePrint::static := proc()
            local s,j;
                s := "%d";
                for j from 1 to d-1 do
                    s := cat(s,"x%d");
                end do;
                s := cat(s," image");
                return nprintf(s,op(ml));
            end proc;
            # im(x): pixel value of the cell containing the point x, or 0.0
            # when x falls outside the grid
            ModuleApply::static := proc(x)
            local il,k;
                il := getcell(x);
                for k from 1 to d do
                    if(il[k]<1 or il[k]>ml[k]) then
                        return 0.0;
                    end if;
                end do;
                return arr[op(il)];
            end proc;
            # sets every pixel to 0.0
            clear::static := proc()
                ArrayTools:-Fill(0.0,arr);
            end proc;
            # getvec[U]() flattens the pixels into the vector U and returns
            # it; plain getvec() first obtains an N-entry vector from
            # allocla (not defined in this file)
            getvec::static := proc()
            local U;
                if(not type(procname,indexed)) then
                    return getvec[allocla(N)](procname);
                end if;
                U := op(procname);
                return arr2vec[ml](arr,U);
            end proc;
            # overwrites the pixels from the flat vector U
            setvec::static := proc(U)
                vec2arr[ml](U,arr);
                return;
            end proc;
            # overwrites the pixels with a copy of arr1
            setarr::static := proc(arr1)
                ArrayTools:-Copy(arr1,arr);
            end proc;
            # sets every pixel to f evaluated at its cell center; f receives
            # the center as a single list of coordinates
            setmap::static := proc(f)
            local k,il,x;
                for k from 1 to N do
                    il := getcoords(k);
                    x := getpoint(k);
                    arr[op(il)] := f(x);
                end do;
                return;
            end proc;
            # list of linear indices of the cells whose centers are within
            # distance r of the point x (uses the object's scratch buffers)
            getball::static := proc(x,r)
            local j,N1,i;
                for j from 1 to d do
                    V1[j] := x[j];
                end do;
                N1 := getball1(V1,r,inds0);
                return [seq(inds0[i],i=1..N1)];
            end proc;
            # low-level form of getball: V is a float[8] vector, the hits
            # are written to inds and their count is returned
            getball1::static := proc(V,r,inds)
                return getball0(V,r,al1,bl1,ml1,epsl1,inds,d,tl1,J0,J1);
            end proc;
            # N x d float[8] Matrix whose k-th row is the center of cell k
            getdata::static := proc()
            local k;
                return Matrix([seq(getpoint(k),k=1..N)],datatype=float[8]);
            end proc;
            # center of the cell with linear index k, as a list;
            # getpoint[V](k) writes the coordinates into the float[8]
            # Vector V and returns V instead
            getpoint::static := proc(k)
            local il,ans,j,V;
                il := getcoords(k);
                ans := [seq(evalf(al[j]+(il[j]-.5)/ml[j]*(bl[j]-al[j])),j=1..d)];
                if(not type(procname,indexed)) then
                    return ans;
                elif(type(op(procname),'Vector'(datatype=float[8]))) then
                    # the type test takes the index and the type as two
                    # separate arguments
                    V := op(procname);
                    for j from 1 to d do
                        V[j] := ans[j];
                    end do;
                    return V;
                else
                    error "the index of getpoint must be a float[8] Vector";
                end if;
            end proc;
            # 1-based grid coordinates of the cell containing the point x
            # (may lie outside 1..ml[i] when x is outside the box)
            getcell::static := proc(x)
            local ans,i;
                ans := [seq(ceil((x[i]-al[i])/(bl[i]-al[i])*ml[i]),i=1..d)];
            end proc;
            # converts a 1-based linear index into 1-based grid coordinates,
            # last coordinate varying fastest
            getcoords::static := proc(k)
            local k1,ans,j,r;
                k1 := k-1;
                ans := [];
                for j from d to 1 by -1 do
                    r := k1 mod ml[j];
                    ans := [r+1,op(ans)];
                    k1 := (k1-r)/ml[j];
                end do;
                return ans;
            end proc;
            # inverse of getcoords: grid coordinates to 1-based linear index
            getind::static := proc(il)
            local ans,j;
                ans := il[1]-1;
                for j from 2 to d do
                    ans := ans*ml[j]+il[j]-1;
                end do;
                ans := ans+1;
                return ans;
            end proc;
            # renders the pixels through drawarr with this object's color
            # map; any arguments (or an index on draw) are passed on as the
            # index of drawarr
            draw::static := proc()
                if(type(procname,indexed)) then
                    return draw(args,op(procname));
                end if;
                if(nargs>0) then
                    return drawarr[args](arr,cmap);
                else
                    return drawarr(arr,cmap);
                end if;
            end proc;
            # the box as a sequence of ranges al[i]..bl[i]
            getrng::static := proc()
            local i;
                return seq(al[i]..bl[i],i=1..d);
            end proc;
            # constructor body: fills in the exported fields and allocates
            # the scratch buffers used by getball
            init::static := proc()
            local i,j,r,m;
                if(type(args[1],'list')) then
                    ml := args[1];
                    # one multi-dimensional float[8] Array with dimensions ml
                    arr := Array(seq(1..m,m=ml),datatype=float[8]);
                elif(type(args[1],'Array'(datatype=float[8]))) then
                    arr := args[1];
                    ml := [arrdim(arr)];
                else
                    error "first argument must be a list of dimensions or a float[8] Array";
                end if;
                d := nops(ml);
                N := convert(ml,`*`);
                if(nargs>=2) then
                    rng := args[2];
                else
                    rng := [seq(-1..1,i=1..d)];
                end if;
                if(nargs=3) then
                    cmap := args[3];
                else
                    cmap := colormap('virdiv');
                end if;
                al := [seq(evalf(op(1,r)),r=rng)];
                bl := [seq(evalf(op(2,r)),r=rng)];
                epsl := [seq((bl[j]-al[j])/ml[j],j=1..d)];
                # hardware-typed copies for the compiled kernel getball0
                al1 := Vector(al,datatype=float[8]);
                bl1 := Vector(bl,datatype=float[8]);
                ml1 := Vector(ml,datatype=integer[4]);
                epsl1 := Vector(epsl,datatype=float[8]);
                # allocla is not defined in this file
                inds0,J0,J1,tl1 := allocla[integer[4]](N,d,d,d);
                V1 := allocla[float[8]](d);
                return;
            end proc;
            init(op(argl));
        end module;
        return md;
    end proc;

    # map2arr[ml](f,rng,...): samples the function f at the cell centers of
    # an ml[1] x ... x ml[d] grid over the box rng (a list of ranges a..b)
    # and returns the result wrapped in an image object built by getim.
    # f is called with the d coordinates as separate arguments.  The grid
    # dimensions go in the index, either as one list, map2arr[[m1,m2]], or
    # as a plain sequence, map2arr[m1,m2].  Arguments after rng are passed
    # on to getim.
    map2arr := proc(f,rng)
    local d,ml,al,bl,ans,N,k,k1,q,il,xl,i,j,m;
        if(not type(procname,indexed)) then
            error "grid dimensions must be supplied as an index, e.g. map2arr[[m1,m2]](f,rng)";
        elif(not type(op(1,procname),'list')) then
            # normalise map2arr[m1,m2] to map2arr[[m1,m2]]
            return map2arr[[op(procname)]](args);
        end if;
        ml := op(procname);
        d := nops(rng);
        if(nops(ml)<>d) then
            error "got %1 grid dimension(s) but %2 range(s)",nops(ml),d;
        end if;
        al,bl := [seq(evalf(op(1,rng[i])),i=1..d)],[seq(evalf(op(2,rng[i])),i=1..d)];
        ans := Array(seq(1..m,m=ml),datatype=float[8]);
        N := convert(ml,`*`);
        for k from 1 to N do
            # decode k into 1-based grid coordinates (last one fastest);
            # every cell is visited exactly once, so the visiting order does
            # not affect the result
            k1 := k-1;
            il := [];
            for j from d to 1 by -1 do
                q := irem(k1,ml[j]);
                il := [q+1,op(il)];
                k1 := iquo(k1,ml[j]);
            end do;
            xl := [seq(al[j]+(il[j]-.5)*(bl[j]-al[j])/ml[j],j=1..d)];
            ans[op(il)] := evalf(f(op(xl)));
        end do;
        return getim(ans,rng,args[3..nargs]);
    end proc;

    # drawarr(arr,cmap): colors the 2-D array arr through the color map cmap
    # and displays the resulting RGB image; drawarr[fn](arr,cmap) writes the
    # image to the file fn instead and returns nothing.
    #   arr   Array or Matrix; integer[4] data is looked up by index
    #         (cmap[v]), everything else is passed through the map (cmap(v))
    #   cmap  a color map object, or a symbol/string naming one (default
    #         'viridis'), which is resolved with colormap
    # When the global viewmode is true, the image is laid out with the two
    # axes swapped and the second one reversed.
    drawarr := proc(arr,cmap:='viridis')
    local m1,m2,arr1,i1,i2,c,useindex,rotate,A,rng,m;
    global viewmode;
        if(type(cmap,'symbol') or type(cmap,'string')) then
            return procname(arr,colormap(cmap));
        elif(type(arr,'Matrix'(datatype=float[8]))) then
            return procname(Array(arr,datatype=float[8]),cmap);
        elif(type(arr,'Matrix'(datatype=integer[4]))) then
            return procname(Array(arr,datatype=integer[4]),cmap);
        end if;
        m1,m2 := seq(op(2,rng),rng=ArrayTools:-Dimensions(arr));
        # viewmode may be unassigned if ModuleLoad has not run; treat
        # anything but true as false
        rotate := evalb(viewmode=true);
        useindex := type(arr,'Array'(datatype=integer[4]));
        if(rotate) then
            arr1 := Array(seq(1..m,m=[m2,m1,3]),datatype=float[8]);
        else
            arr1 := Array(seq(1..m,m=[m1,m2,3]),datatype=float[8]);
        end if;
        for i1 from 1 to m1 do
            for i2 from 1 to m2 do
                if(useindex) then
                    c := cmap[arr[i1,i2]];
                else
                    c := cmap(arr[i1,i2]);
                end if;
                if(rotate) then
                    arr1[m2-i2+1,i1,1],arr1[m2-i2+1,i1,2],arr1[m2-i2+1,i1,3] := c[];
                else
                    arr1[i1,i2,1],arr1[i1,i2,2],arr1[i1,i2,3] := c[];
                end if;
            end do;
        end do;
        A := ImageTools:-Create(arr1);
        if(type(procname,indexed) and nops(procname)<>0) then
            ImageTools:-Write(op(procname),A);
            return;
        end if;
        return ImageTools:-Embed(A);
    end proc;

    # drawmap[ml](f,rng,cmap): builds an image object with grid dimensions
    # ml over the box rng (forwarding rng and cmap to getim), fills it by
    # evaluating f at every cell center via setmap, and draws it.
    # The dimensions go in the index, either as one list, drawmap[[m1,m2]],
    # or as a plain sequence, drawmap[m1,m2].  Returns whatever draw returns.
    drawmap := proc(f,rng,cmap)
    local ml,im;
        if(not type(procname,indexed)) then
            error "grid dimensions must be supplied as an index, e.g. drawmap[[m1,m2]](f,rng)";
        end if;
        if(nops(procname)=1 and type(op(1,procname),'list')) then
            ml := op(1,procname);
        else
            ml := [op(procname)];
        end if;
        im := getim(ml,args[2..nargs]);
        im:-setmap(f);
        # NOTE(review): densfn is not defined in this procedure; it resolves
        # to a name from an outer scope and is forwarded to drawarr as its
        # index, i.e. presumably as an output file name - confirm.
        return im:-draw(densfn);
    end proc;

    # colormap(s): builds a color map object.
    #   s a ColorMap object  -> a new map built from that object's anchor table tab0
    #   s a table            -> a map interpolated from the anchors in the
    #                           table (keys 0..1000, see colortab)
    #   anything else        -> delegated to colormap0 (named palettes)
    # The object can be applied to a value, cm(x), which maps x through the
    # parametrization f and the scale to a table index, or indexed, cm[i],
    # which looks up entry i directly.
    colormap := proc(s)
    local argl,md;
        if(whattype(args[1])='ColorMap') then
            return colormap(args[1]:-tab0);
        elif(not type(args[1],'table')) then
            return colormap0(args);
        end if;
        argl := [s];
        md := module()
        option object;
        export `?[]`,tab,tab0,inds,getcolor,tikz,colname,f,colmap,scale,reverse,col0,col1,`whattype`;
        local ModulePrint,ModuleApply;
            # reports the name ColorMap, used by `type/ColorMap`
            `whattype`::static := proc()
                return 'ColorMap';
            end proc;
            # prints five sample colors, then the map's name
            ModulePrint::static := proc()
            local i;
                print(seq(getcolor(250*i),i=0..4));
                # print the name verbatim rather than as a format string
                return nprintf("%s",colname);
            end proc;
            # cm(x) is colmap(x)
            ModuleApply::static := proc(x)
                return colmap(x);
            end proc;
            # flips the interpolated table and swaps the two out-of-range
            # colors; tab0 and inds are left unchanged
            reverse::static := proc()
            local i;
                tab := table([seq(i=tab[1000-i],i=0..1000)]);
                col0,col1 := col1,col0;
                return;
            end proc;
            # cm[i] is getcolor(i)
            `?[]`::static := proc()
                return getcolor(op(args[2]));
            end proc;
            # color for the value x: f(x/scale) is scaled by 1000 and rounded
            # to a table index
            colmap := proc(x)
            local y,i;
                y := f(x/scale);
                i :=  round(1000*y);
                return getcolor(i);
            end proc;
            # table entry i, with col0 below 0 and col1 above 1000
            getcolor::static := proc(i)
                if(i<0) then
                    return col0;
                elif(i>1000) then
                    return col1;
                end if;
                return tab[i];
            end proc;
            # pgfplots definition of this map sampled at l evenly spaced
            # table entries (a single sample uses entry 0)
            tikz::static := proc(l)
            local ans,ans1,i,j;
                ans := "\\pgfplotsset{\n";
                ans := cat(ans,"colormap/",colname,"/.style={/pgfplots/colormap={",colname,"}{\n");
                for i from 1 to l do
                    if(l>1) then
                        j := round(1000*(i-1)/(l-1));
                    else
                        # avoid 0/0 when only one sample is requested
                        j := 0;
                    end if;
                    ans1 := nprintf("rgb=(%f,%f,%f)\n",op(getcolor(j)[1..3]));
                    ans := cat(ans,ans1);
                end do;
                ans := cat(ans,"},\n},\n}\n");
                return nprintf(ans);
            end proc;
            tab0 := op(argl);
            tab := colortab(tab0);
            inds := map(op,[indices(tab0)]);
            colname := "mycolor";
            f := colpar("div");
            col0,col1 := tab[0],tab[1000];
            scale := 1.0;
        end module;
        return md;
    end proc;

    # colormap0(s,par): builds one of the named palettes.  s is the palette
    # name (symbol or string, case-insensitive): viridis, virdiv, species1,
    # species2, bioheat, bioheata, ising, seq1 or div1.  Each palette is a
    # table of anchor colors (0..255 RGB triples) keyed by position 0..1000,
    # which colormap interpolates.  The value parametrization defaults to
    # x/sqrt(1+x^2), or to colpar('div') for the diverging palettes; an
    # optional second argument is passed to colpar and overrides it.
    # Raises an error for an unknown name.
    colormap0 := proc(s)
    uses StringTools;
    local x,colname,tab,f,md;
        colname := LowerCase(convert(s,'string'));
        tab := table();
        f := x->x/sqrt(1+x^2);
        # NOTE(review): in viridis/virdiv the anchor [59,82,139] sits at 350
        # while the other anchors are spaced by 250 - confirm 350 is intended.
        if(colname="viridis") then
            tab[1000] := [253,231,37];
            tab[750] := [94,201,98];
            tab[500] := [33,145,140];
            tab[350] := [59,82,139];
            tab[0] := [68,1,84];
        elif(colname="virdiv") then
            tab[1000] := [253,231,37];
            tab[750] := [94,201,98];
            tab[500] := [33,145,140];
            tab[350] := [59,82,139];
            tab[0] := [68,1,84];
            f := colpar('div');
        elif(colname="species1") then
            tab[0] := [0,110,29];
            tab[250] := [77,172,115];
            tab[500] := [179,220,195];
            tab[750] := [255,255,204];
            tab[1000] := [255,204,128];
        elif(colname="species2") then
            tab[1000] := map(round,[1.0,.8,.502]*255);
            tab[800] := map(round,[1.0,1.0,.8]*255);
            tab[600] := map(round,[.702,.863,.765]*255);
            tab[400] := map(round,[.302,.6745,.451]*255);
            tab[200] := map(round,[0.0,.431,.114]*255);
            tab[0] := map(round,[0.0,.3,.05]*255);
        elif(colname="bioheat") then
            tab[0] := [25,74,0];
            tab[167] := [106,141,21];
            tab[333] := [152,182,24];
            tab[500] := [225,235,93];
            tab[667] := [243,172,51];
            tab[833] := [239,127,40];
            tab[1000] := [223,58,27];
        elif(colname="bioheata") then
            # bioheat with the anchor positions mirrored (i -> 1000-i)
            tab[1000] := [25,74,0];
            tab[833] := [106,141,21];
            tab[667] := [152,182,24];
            tab[500] := [225,235,93];
            tab[333] := [243,172,51];
            tab[167] := [239,127,40];
            tab[0] := [223,58,27];
        elif(colname="ising") then
            tab[0] := [101,127,214];
            tab[167] := [125,139,200];
            tab[333] := [148,145,174];
            tab[500] := [171,152,145];
            tab[667] := [193,162,108];
            tab[833] := [222,167,77];
            tab[1000] := [234,177,61];
        elif(colname="seq1") then
            tab[0] := [214,238,245];
            tab[250] := [157,195,210];
            tab[500] := [115,167,185];
            tab[750] := [71,135,158];
            tab[1000] := [34,110,135];
        elif(colname="div1") then
            tab[0] := [67,139,130];
            tab[250] := [167,194,164];
            tab[500] := [241,236,216];
            tab[750] := [231,168,119];
            tab[1000] := [195,94,52];
            f := colpar('div');
        else
            error "unknown color map name: %1",colname;
        end if;
        md := colormap(eval(tab));
        md:-f := f;
        if(nargs=2) then
            md:-f := colpar(args[2]);
        end if;
        md:-colname := colname;
        return md;
    end proc;

#Color parametrization: returns a function of x; the allowable range maps
#into 0..1.  typ names the parametrization (string or symbol):
#  "div"  : x -> (1+x/sqrt(1+x^2))/2
#  "heat" : x -> x/sqrt(1+x^2)
#  "lin"  : x -> (x-a)/(b-a), where a..b is the optional second argument
#           (default 0..1)
#Raises "no such color parametrization" for anything else.
    colpar := proc(typ)
    local x,rng,a,b,nm;
        if(type(typ,'symbol')) then
            # symbols are looked up by their name; extra arguments (the
            # range for "lin") are kept
            return colpar(convert(typ,'string'),args[2..nargs]);
        elif(not type(typ,'string')) then
            # typeform is not defined in this file; only follow it when it
            # actually produces a string, so an unknown input cannot recurse
            # forever
            nm := typeform(typ);
            if(not type(nm,'string')) then
                error "no such color parametrization";
            end if;
            return colpar(nm,args[2..nargs]);
        elif(typ="div") then
            return x->(1+x/sqrt(1+x^2))/2;
        elif(typ="heat") then
            return x->x/sqrt(1+x^2);
        elif(typ="lin") then
            if(nargs=2) then
                rng := args[2];
            else
                rng := 0..1;
            end if;
            a,b := op(rng);
            return x->((x-a)/(b-a));
        else
            error "no such color parametrization";
        end if;
    end proc;

    # colortab(tab): expands a table of anchor colors, keyed by integer
    # positions that must start at 0 and end at 1000, into a table with an
    # entry for every integer 0..1000.  Entries between two neighbouring
    # anchors are linear blends of the first three components of the two
    # anchor colors.  Raises an error when the anchors do not span 0..1000.
    colortab := proc(tab)
    local keys,out,pos,lo,hi,clo,chi,idx,w;
        keys := sort(map(op,[indices(tab)]));
        if(not (min(keys)=0 and max(keys)=1000)) then
            error "must have endpoints of 0,1000";
        end if;
        out := table();
        out[0] := Color(tab[0]);
        # walk over consecutive anchor pairs (lo,hi)
        for pos from 1 to nops(keys)-1 do
            lo := keys[pos];
            hi := keys[pos+1];
            clo := Color(tab[lo]);
            chi := Color(tab[hi]);
            for idx from lo+1 to hi do
                # blend weight: 0 at lo, 1 at hi
                w := evalf((idx-lo)/(hi-lo));
                out[idx] := Color((1-w)*clo[1..3]+w*chi[1..3]);
            end do;
        end do;
        return eval(out);
    end proc;

    # colorvec(tab,bg0,bg1): builds a 1000-entry Vector of colors from a
    # table of anchor colors keyed by integer position.  Entries between two
    # neighbouring anchors are linear blends of the first three components
    # of the anchor colors; entries below the first anchor are bg0 and
    # entries above the last anchor are bg1.  bg0 defaults to the first
    # anchor color and bg1 to the last.  A non-table first argument is first
    # converted with colorvec0, which is not defined in this file.
    # Raises an error for an empty table.
    colorvec := proc(tab,bg0,bg1)
    local il,l,V,i,j,i1,i2,col0,col1,t;
        if(not type(args[1],'table')) then
            return colorvec(colorvec0(args[1]),args[2..nargs]);
        end if;
        il := sort(map(op,[indices(tab)]));
        l := nops(il);
        if(l=0) then
            error "the color table is empty";
        end if;
        if(nargs=1) then
            return colorvec(tab,tab[il[1]]);
        elif(nargs=2) then
            return colorvec(tab,bg0,tab[il[l]]);
        end if;
        V := Vector(1000);
        for i from 1 to min(il[1]-1,1000) do
            V[i] := Color(bg0);
        end do;
        for j from 1 to l-1 do
            i1,i2 := il[j],il[j+1];
            col0,col1 := Color(tab[i1]),Color(tab[i2]);
            # V is indexed 1..1000, so positions outside that range (such as
            # an anchor at 0) are skipped instead of raising an index error
            for i from max(i1,1) to min(i2,1000) do
                t := evalf((i-i1)/(i2-i1));
                V[i] := Color((1-t)*col0[1..3]+t*col1[1..3]);
            end do;
        end do;
        for i from max(il[l]+1,1) to 1000 do
            V[i] := Color(bg1);
        end do;
        return V;
    end proc;

end module;
