with(ColorTools);
loadlib("TensMaps");

ColMaps := module()
option package;
export colortab,colorvec,colormap,revcolor,tikzcolor,heatmap,discmap,solidmap;
local maxfinite;

    # colormap(colvec) -- build a 'ColMap' object from a 1001-entry colour Vector.
    #
    # Arguments:
    #   colvec - either an existing ColMap (returned unchanged), a Vector of
    #            colours indexed 1..1001, or anything else, which is first passed
    #            to colorvec(...) to obtain such a Vector.
    # Returns:
    #   an object module whose whattype is 'ColMap'.  Exported members:
    #     colvec        - private copy of the colour Vector
    #     col1, col2    - colours returned for indices below 0 / above 1000
    #                     (both initialised to Color([255,255,255]))
    #     getcolor(i)   - colour for an integer index i (0..1000, else col1/col2)
    #     colmap(x)     - colour for x, using index round(1000*x)
    #     getalpha(i)   - 0.0 for i in 0..1000, 1.0 otherwise
    #     alphamap(x)   - getalpha(round(1000*x))
    #     obj[i]        - same as getcolor(i)
    #   Applying the object (ModuleApply) forwards to colmap.
    #
    # Fixes relative to the previous version: getcolor is now exported (the
    # export list only contained `getcol`, which is never defined; it is kept
    # so that existing references to that name still resolve), ModuleApply is
    # declared, and all procedure locals are declared explicitly.
    colormap := proc(colvec)
    local md;
        if(whattype(colvec)='ColMap') then
            return colvec;
        elif(not type(colvec,'Vector')) then
            # Not a Vector: let colorvec interpret the arguments (name/table).
            return colormap(colorvec(args));
        end if;
        md := module()
        option object;
        export `?[]`,`whattype`,colvec,getcol,getcolor,colmap,col1,col2,init,alphamap,getalpha;
        local ModulePrint,ModuleApply;
            # Pretty printer: shows the colours at indices 0,250,...,1000.
            ModulePrint::static := proc()
            local i;
                print(seq(getcolor(250*i),i=0..4));
                return nprintf("color map");
            end proc;
            # obj(x) behaves like obj:-colmap(x).
            ModuleApply::static := proc(x)
                return colmap(x);
            end proc;
            # Makes whattype(obj) report 'ColMap'.
            `whattype`::static := proc()
                return 'ColMap';
            end proc;
            # Index overload: args[2] is the list of indices given in obj[...].
            `?[]`::static := proc()
                return getcolor(op(args[2]));
            end proc;
            # Map x to a colour through the integer index round(1000*x).
            colmap := proc(x)
            local i;
                i :=  round(1000*x);
                return getcolor(i);
            end proc;
            # Integer-index lookup; indices outside 0..1000 give col1 / col2.
            getcolor::static := proc(i)
                if(i<0) then
                    return col1;
                elif(i>1000) then
                    return col2;
                end if;
                # colvec is 1-based, the index range is 0-based.
                return colvec[i+1];
            end proc;
            # 0.0 inside the index range, 1.0 outside of it.
            getalpha := proc(i)
                if(i>=0 and i<=1000) then
                    return 0.0;
                else
                    return 1.0;
                end if;
            end proc;
            alphamap::static := proc(x)
            local i;
                i := round(1000*x);
                return getalpha(i);
            end proc;
            # Store a copy of V and reset both out-of-range colours.
            init::static := proc(V)
                colvec := Vector(V);
                col1 := Color([255,255,255]);
                col2 := Color([255,255,255]);
            end proc;
        end module;
        md:-init(colvec);
        return md;
    end proc;

    # revcolor(cmap) -- return a new ColMap whose colour Vector is that of
    # cmap in reverse order.
    #
    # cmap may be a ColMap or anything colormap(...) accepts.  The out-of-range
    # colours col1/col2 are copied from cmap unchanged (they are not swapped).
    #
    # The type test uses the quoted symbol 'ColMap' (as colormap does), so it
    # keeps working even if a global named ColMap has been assigned.
    revcolor := proc(cmap)
    local src, rev, i;
        if(whattype(cmap)<>'ColMap') then
            return revcolor(colormap(cmap));
        end if;
        src := cmap:-colvec;
        # Entry i of the result is entry 1002-i (= 1001-i+1) of the source.
        rev := colormap(Vector([seq(src[1002-i],i=1..1001)]));
        rev:-col1 := cmap:-col1;
        rev:-col2 := cmap:-col2;
        return rev;
    end proc;

# Create the TikZ/pgfplots code that defines the color map, sampled at the positions given by rng
    # tikzcolor(cmap, cname, rng) -- produce a \pgfplotsset{...} snippet that
    # defines a pgfplots colormap called cname from the colours of cmap.
    #
    # Arguments:
    #   cmap  - a ColMap, or anything colormap(...) accepts
    #   cname - name used for the pgfplots style and colormap
    #   rng   - list of colour indices (0..1000) to sample, or a number n >= 2
    #           meaning n evenly spaced indices from 0 to 1000
    # Returns:
    #   the generated text as a name (via nprintf).
    #
    # Fixes relative to the previous version:
    #   * the evenly spaced index list had a misplaced parenthesis, passing the
    #     range to round instead of seq;
    #   * the style name used the undefined variable cmname instead of cname;
    #   * colours were looked up with an unbound getcolor and cmap was ignored;
    #     they are now read from the colour map via cmap[i];
    #   * n < 2 is rejected explicitly instead of dividing by zero;
    #   * the result is emitted through a "%s" format so that a '%' in the
    #     generated text is not interpreted as a format directive.
    tikzcolor := proc(cmap,cname,rng)
    local cm, n, i, ans;
        if(type(rng,'numeric')) then
            n := rng;
            if(n<2) then
                error "need at least 2 sample points, got %1", n;
            end if;
            return procname(cmap,cname,[seq(round(1000*(i-1)/(n-1)),i=1..n)]);
        end if;
        if(whattype(cmap)='ColMap') then
            cm := cmap;
        else
            cm := colormap(cmap);
        end if;
        # cm[i] is the colour at index i; a trailing [] unpacks its components,
        # as done elsewhere in this file.
        ans := cat("\\pgfplotsset{\n",
                   "colormap/",cname,"/.style={/pgfplots/colormap={",cname,"}{\n",
                   seq(sprintf("rgb=(%f,%f,%f)\n",cm[i][]),i in rng),
                   "},\n},\n}\n");
        return nprintf("%s",ans);
    end proc;

    # colortab(s) -- return the table of colour stops for the palette named s.
    #
    # Arguments:
    #   s - palette name (string or symbol, case-insensitive).  Unknown names
    #       are passed to Color(s) and yield a constant palette of that colour.
    # Returns:
    #   a table mapping stop positions (integers, always including 0 and 1000)
    #   to [r,g,b] lists with components in 0..255.
    #
    # Fix relative to the previous version: the "bioheata" palette (the mirror
    # image of "bioheat") had a stop at 677; the mirror of bioheat's stop 333
    # is 667.  The unused local x was removed and all locals are declared.
    colortab := proc(s)
    uses StringTools;
    local cmname, tab, c, to255;
        # Convert a list of 0..1 components into rounded 0..255 integers.
        to255 := rgb -> map(round,rgb*255);
        cmname := LowerCase(convert(s,'string'));
        if(cmname="viridis" or cmname="virdiv") then
            # "virdiv" uses exactly the same stops as "viridis".
            tab := table([0=[68,1,84], 350=[59,82,139], 500=[33,145,140],
                          750=[94,201,98], 1000=[253,231,37]]);
        elif(cmname="viridisa") then
            tab := table([0=[253,231,37], 250=[94,201,98], 500=[33,145,140],
                          650=[59,82,139], 1000=[68,1,84]]);
        elif(cmname="species1") then
            tab := table([0=[0,110,29], 250=[77,172,115], 500=[179,220,195],
                          750=[255,255,204], 1000=[255,204,128]]);
        elif(cmname="grayblack") then
            tab := table([0=[230,230,230], 1000=[0,0,0]]);
        elif(cmname="blackbody") then
            tab := table([0=[0,0,0], 390=[178,34,34], 580=[227,105,5],
                          890=[230,230,53], 1000=[255,255,255]]);
        elif(cmname="species2") then
            tab := table([0=to255([0.0,.3,.05]), 200=to255([0.0,.431,.114]),
                          400=to255([.302,.6745,.451]), 600=to255([.702,.863,.765]),
                          800=to255([1.0,1.0,.8]), 1000=to255([1.0,.8,.502])]);
        elif(cmname="plasma") then
            tab := table([0=[13,8,135], 140=[84,2,163], 290=[139,10,165],
                          430=[185,50,137], 570=[219,92,104], 710=[244,136,73],
                          860=[254,188,43], 1000=[240,249,33]]);
        elif(cmname="math1") then
            tab := table([0=[88,148,156], 333=[237,230,211], 667=[241,215,180],
                          1000=[223,152,130]]);
        elif(cmname="math2") then
            tab := table([0=[1,20,115], 250=[149,172,216], 500=[130,215,194],
                          750=[238,217,150], 1000=[182,140,56]]);
        elif(cmname="turquoise") then
            tab := table([0=[208,206,157], 250=[153,195,159], 500=[32,153,144],
                          750=[0,110,117], 1000=[7,40,49]]);
        elif(cmname="chillblue") then
            tab := table([0=[28,55,72], 250=[39,91,113], 500=[213,230,240],
                          750=[240,248,251], 1000=[15,29,38]]);
        elif(cmname="kitchen") then
            tab := table([0=[204,183,136], 250=[175,214,219], 500=[37,38,84],
                          750=[154,170,167], 1000=[248,237,209]]);
        elif(cmname="bioheat") then
            tab := table([0=[25,74,0], 167=[106,141,21], 333=[152,182,24],
                          500=[225,235,93], 667=[243,172,51], 833=[239,127,40],
                          1000=[223,58,27]]);
        elif(cmname="bioheata") then
            # Mirror image of "bioheat": stop p there is stop 1000-p here.
            tab := table([0=[223,58,27], 167=[239,127,40], 333=[243,172,51],
                          500=[225,235,93], 667=[152,182,24], 833=[106,141,21],
                          1000=[25,74,0]]);
        elif(cmname="ising") then
            tab := table([0=[101,127,214], 167=[125,139,200], 333=[148,145,174],
                          500=[171,152,145], 667=[193,162,108], 833=[222,167,77],
                          1000=[234,177,61]]);
        elif(cmname="seq1") then
            tab := table([0=[214,238,245], 250=[157,195,210], 500=[115,167,185],
                          750=[71,135,158], 1000=[34,110,135]]);
        elif(cmname="div1") then
            tab := table([0=[67,139,130], 250=[167,194,164], 500=[241,236,216],
                          750=[231,168,119], 1000=[195,94,52]]);
        elif(cmname="div2") then
            tab := table([0=to255([.23,.299,.754]), 500=to255([.865,.865,.865]),
                          1000=to255([.706,.334,.046])]);
        elif(cmname="bw") then
            tab := table([0=[0,0,0], 1000=[255,255,255]]);
        elif(cmname="wb") then
            tab := table([0=[255,255,255], 1000=[0,0,0]]);
        elif(cmname="bb") then
            tab := table([0=[0,0,0], 1000=[0,0,0]]);
        else
            # Not a known palette: treat s as a single colour specification.
            c := to255([Color(s)[]]);
            tab := table([0=c, 1000=c]);
        end if;
        return tab;
    end proc;

    # colorvec(tab) -- expand a table of colour stops into a 1001-entry Vector.
    #
    # Arguments:
    #   tab - table mapping stop positions to colour specifications accepted by
    #         Color; anything that is not a table is first passed to colortab.
    # Returns:
    #   a Vector whose entry k+1 is the colour for position k (k = 0..1000),
    #   linearly interpolated between neighbouring stops.
    # Errors:
    #   "must have endpoints of 0,1000" unless the smallest stop is 0 and the
    #   largest is 1000.
    colorvec := proc(tab)
    local stops, seg, lo, hi, cLo, cHi, k, frac, out;
        if(not type(args[1],'table')) then
            return colorvec(colortab(args));
        end if;
        stops := sort(map(op,[indices(tab)]));
        if(min(stops)<>0 or max(stops)<>1000) then
            error "must have endpoints of 0,1000";
        end if;
        out := Vector(1001);
        out[1] := Color(tab[0]);
        # Walk over consecutive stop pairs (lo,hi) and fill positions lo+1..hi.
        for seg from 2 to nops(stops) do
            lo := stops[seg-1];
            hi := stops[seg];
            cLo := Color(tab[lo]);
            cHi := Color(tab[hi]);
            for k from lo+1 to hi do
                # frac runs from just above 0 up to exactly 1 at the upper stop.
                frac := evalf((k-lo)/(hi-lo));
                out[k+1] := Color((1-frac)*cLo[1..3]+frac*cHi[1..3]);
            end do;
        end do;
        return out;
    end proc;

    # maxfinite(B) -- largest entry of B that is a finite number; entries that
    # are not numeric, or that are not below infinity, are skipped.  Returns
    # -infinity when B has no finite entry.
    # (maxfinite is declared local to this package but previously had no
    # definition, so the automatic range in heatmap could not be computed.)
    maxfinite := proc(B)
    local v, best;
        best := -infinity;
        for v in B do
            if(type(v,'numeric') and v<infinity and v>best) then
                best := v;
            end if;
        end do;
        return best;
    end proc;

    # heatmap -- object that renders a matrix as an image through a colour map.
    #
    #   heatmap(B) / heatmap(B,r) / heatmap:-draw(B,r)
    #       B - matrix of values (its size is read with Dimension)
    #       r - how values are scaled onto the colour map:
    #             omitted or true : from min(B) to the largest finite entry
    #             false           : fixed range 0..1
    #             a number        : from min(B) to r
    #             a range a..b    : from a to b
    #       Builds an m x n x 4 float[8] array: channels 1..3 come from the
    #       colour map, channel 4 is set to 1.0 for values outside the range
    #       and otherwise stays at the array default 0.  The array is passed
    #       to Embed(Create(...)) and also returned.
    #   heatmap:-setcolor(col) - select the colour map (anything colormap accepts)
    #   heatmap:-getcolor()    - return the current colour map
    #
    # Changes relative to the previous version: a leftover
    # "if nargs ... return heatmap(B,0..1)" fragment that sat directly in the
    # module body (outside any procedure) was removed; its apparent intent,
    # r = false meaning the range 0..1, is now handled inside draw.  Any other
    # unsupported r raises an error instead of leaving the range unassigned.
    heatmap := module()
    option object;
    export getcolor,setcolor,draw,cmap;
    local ModuleApply,ModulePrint;
        ModulePrint::static := proc()
            print(cmap);
            return nprintf("draws a heat map");
        end proc;
        draw::static := proc(B,r)
        local m,n,a0,a1,arr,i,j,c;
            m,n := Dimension(B);
            if(nargs=1 or r=true) then
                a0,a1 := min(B),maxfinite(B);
            elif(r=false) then
                a0,a1 := 0,1;
            elif(type(r,'numeric')) then
                a0,a1 := min(B),r;
            elif(type(r,`..`)) then
                a0,a1 := op(r);
            else
                error "unsupported range argument: %1", r;
            end if;
            arr := Array(1..m,1..n,1..4,datatype=float[8]);
            for i from 1 to m do
                for j from 1 to n do
                    # Position of the entry within the range, 0..1 when inside.
                    c := (B[i,j]-a0)/(a1-a0);
                    arr[i,j,1],arr[i,j,2],arr[i,j,3] := cmap(c)[];
                    if(c<0 or c>1) then
                        arr[i,j,4] := 1.0;
                    end if;
                end do;
            end do;
            Embed(Create(arr));
            return arr;
        end proc;
        # Applying the object is the same as calling draw.
        ModuleApply::static := draw;
        getcolor::static := proc()
            return cmap;
        end proc;
        setcolor::static := proc(col)
            cmap := colormap(col);
        end proc;
        # Default colour map.
        setcolor('viridis');
    end module;

    # discmap0(cl) -- build a colour map made of nops(cl) equally wide bands,
    # band number b (1-based) being filled with cl[b].
    #
    # cl is a list of colours; its entries are stored as given and the
    # resulting 1001-entry Vector is handed to colormap.
    discmap0 := proc(cl)
    local nbands, bands;
        nbands := nops(cl);
        # Position p (1..1001) falls into band floor((p-1)*nbands/1001)+1.
        bands := Vector(1001, p -> cl[floor((p-1)*nbands/1001)+1]);
        return colormap(bands);
    end proc;

    # discmap(col,k) -- discretise a colour map into k flat bands.
    #
    # Arguments:
    #   col - a list of colours (then discmap0 is used and each list entry
    #         becomes one band), or anything colormap(...) accepts
    #   k   - number of bands when col is not a list
    # Returns:
    #   a new ColMap in which every position takes the colour found in the
    #   source map at index floor(1000/(k-1)*b), where b = 0..k-1 is the band
    #   the position falls into.
    discmap := proc(col,k)
    local base, src, banded;
        if(type(args[1],'list')) then
            return discmap0(args);
        end if;
        base := colormap(col);
        src := base:-colvec;
        # p is the 1-based Vector position; p-1 is the 0-based colour index.
        banded := Vector(1001, p -> src[floor(1000/(k-1)*floor(k*(p-1)/1001))+1]);
        return colormap(banded);
    end proc;

    # solidmap(col) -- colour map in which every one of the 1001 positions
    # holds the single colour Color(col).
    solidmap := proc(col)
    local shade;
        shade := Color(col);
        return colormap(Vector([shade $ 1001]));
    end proc;

end module;
