with(LinearAlgebra):
loadlib("LinAlg");

TensMaps := module()
option package;
export allocarr,arrdim,arrtype,ordmaps,tensmaps,arr2vec,vec2arr,tensrng,vecpatches,endrange,rastrange,map2arr,map2mat,gridmaps,tensgrid,rowmaj,gridmap;
local tensord0,tensord1,tensmaps0,tensmaps1;

    # Return, as an expression sequence, the second operand of every entry
    # reported by ArrayTools:-Dimensions(A) (presumably the upper bound of
    # each dimension's index range).
    arrdim := proc(A)
    local bounds;
        bounds := ArrayTools:-Dimensions(A);
        return op(map2(op,2,bounds));
    end proc;

    # Allocate one Array per argument and return them as an expression
    # sequence.
    #
    # Usage: allocarr[dtype](dims1, dims2, ...). Each argument is a list of
    # extents (or a single extent); the k-th result has index ranges
    # 1..m for every m in the k-th argument. When called without an index
    # the datatype defaults to float[8].
    #
    # The arguments are wrapped in a list before iterating: iterating over
    # the bare `args` sequence would, for a single list argument such as
    # allocarr([2,3]), walk the list's elements and produce two 1-D Arrays
    # instead of one 2x3 Array.
    allocarr := proc()
    local typ, ml, m;
        if not type(procname,'indexed') then
            return allocarr[float[8]](args);
        end if;
        typ := op(procname);
        return seq(Array(seq(1..m,m=ml),datatype=typ),ml in [args]);
    end proc;

    # Return the element datatype of an Array, Vector or Matrix.
    #
    # Only float[8] and integer[4] are recognised; any other input (a
    # non-rtable, or an rtable with a different datatype) raises an error
    # with an explanatory message.
    #
    # The datatype is read directly with rtable_options rather than by
    # building type expressions from strings with parse.
    arrtype := proc(arr)
    local dt;
        if type(arr,'rtable') then
            dt := rtable_options(arr,'datatype');
            if member(dt,[float[8],integer[4]]) then
                return dt;
            end if;
        end if;
        error "arrtype: expected an Array, Vector or Matrix with datatype float[8] or integer[4]";
    end proc;

    # Normalise a range specification and report its geometry.
    #
    # args[1] is iterated entry by entry; a numeric entry r stands for the
    # range 1..r, any other entry is used as given.
    # Returns four values: the list of lower bounds, the list of upper
    # bounds, the list of extents (upper-lower+1) and the product of the
    # extents.
    tensrng := proc()
    local spec, lo, hi, ext, r;
        spec := [seq(`if`(type(r,'numeric'),1..r,r),r in args[1])];
        lo := map2(op,1,spec);
        hi := map2(op,2,spec);
        ext := zip((a,b)->b-a+1,lo,hi);
        return lo,hi,ext,mul(r,r in ext);
    end proc;

    # Build (by generating and parsing source text) a procedure that maps a
    # linear index j, counted from 1, to the corresponding multi-index list
    # over the ranges described by args[1]. The last dimension varies
    # fastest.
    ordmap0 := proc()
    local lo, hi, sizes, total, nd, k, j, src;
        lo,hi,sizes,total := tensrng(args[1]);
        nd := nops(lo);
        src := cat("proc(j)\n",
                   "option inline;\n",
                   "return [",
                   # one "lower + quotient" term per leading dimension
                   seq(cat("",lo[k],"+iquo(irem(j-1,",mul(sizes[j],j=k..nd),"),",
                           mul(sizes[j],j=k+1..nd),"),"),k=1..nd-1),
                   lo[nd],"+irem(j-1,",sizes[nd],")];\n",
                   "end proc;\n");
        return parse(src);
    end proc;

    # Build (by generating and parsing source text) a procedure that maps a
    # multi-index list xl to a linear index: a constant offset plus a
    # stride-weighted sum of the entries of xl, where the stride of
    # dimension k is the product of the extents of all later dimensions.
    ordmap1 := proc()
    local lo, hi, sizes, total, nd, stride, offset, k, j, src;
        lo,hi,sizes,total := tensrng(args[1]);
        nd := nops(lo);
        stride := [seq(mul(sizes[j],j=k+1..nd),k=1..nd)];
        # chosen so that the multi-index of all lower bounds maps to 1
        offset := 1-add(lo[k]*stride[k],k=1..nd);
        src := cat("proc(xl)\n",
                   "option inline;\n",
                   "return ",offset,
                   seq(cat("+",stride[k],"*xl[",k,"]"),k=1..nd),
                   ";\n",
                   "end proc;\n");
        return parse(src);
    end proc;

    # Return the pair of generated index maps for rng: first the
    # linear-index -> multi-index map (ordmap0), then the
    # multi-index -> linear-index map (ordmap1).
    ordmaps := proc(rng)
    local toindex, tolinear;
        toindex := ordmap0(rng);
        tolinear := ordmap1(rng);
        return toindex,tolinear;
    end proc;

    # Generate and compile a procedure (vec,arr) that copies vec[1], vec[2],
    # ... into arr over the index ranges of rng. The loop bounds are baked
    # into the generated source; the last dimension is the innermost loop.
    # dtype is the datatype named in the generated parameter declarations.
    tensmaps0 := proc(rng,dtype)
    local lo, hi, sizes, total, nd, tname, j, src;
        lo,hi,sizes,total := tensrng(rng);
        nd := nops(sizes);
        tname := convert(dtype,'string');
        src := cat("proc(vec::Array(datatype=",tname,"),arr::Array(datatype=",tname,"))\n",
                   "k := 0;\n",
                   seq(cat("for i",j," from ",lo[j]," to ",hi[j]," do\n"),j=1..nd),
                   "k := k+1;\n",
                   "arr[","i1",seq(cat(",i",j),j=2..nd),"] := vec[k];\n",
                   seq("end do;\n",j=1..nd),
                   "end proc;\n");
        return Compiler:-Compile(parse(src));
    end proc;

    # Generate and compile a procedure (arr,vec) that copies the entries of
    # arr over the index ranges of rng into vec[1], vec[2], ... The loop
    # bounds are baked into the generated source; the last dimension is the
    # innermost loop. dtype is the datatype named in the generated parameter
    # declarations.
    tensmaps1 := proc(rng,dtype)
    local lo, hi, sizes, total, nd, tname, j, src;
        lo,hi,sizes,total := tensrng(rng);
        nd := nops(sizes);
        tname := convert(dtype,'string');
        src := cat("proc(arr::Array(datatype=",tname,"),vec::Array(datatype=",tname,"))\n",
                   "k := 0;\n",
                   seq(cat("for i",j," from ",lo[j]," to ",hi[j]," do\n"),j=1..nd),
                   "k := k+1;\n",
                   "vec[k] := arr[","i1",seq(cat(",i",j),j=2..nd),"];\n",
                   seq("end do;\n",j=1..nd),
                   "end proc;\n");
        return Compiler:-Compile(parse(src));
    end proc;

    # Return the two compiled copy routines for rng and dtype: first the
    # vector -> array routine (tensmaps0), then the array -> vector routine
    # (tensmaps1).
    tensmaps := proc(rng,dtype)
    local unpack, pack;
        unpack := tensmaps0(rng,dtype);
        pack := tensmaps1(rng,dtype);
        return unpack,pack;
    end proc;

    # Generate and compile a d-dimensional array -> vector copy routine.
    # The generated procedure takes (arr, vec, a1..ad, b1..bd) and copies
    # arr[i1,...,id] for ij from aj to bj into vec[1], vec[2], ..., with the
    # last dimension as the innermost loop. Results are cached per (d,dtype)
    # through option remember.
    arr2vec0 := proc(d,dtype)
    local tname, j, src;
    option remember;
        tname := convert(dtype,'string');
        src := cat("proc(arr::Array(datatype=",tname,"),vec::Array(datatype=",tname,")",
                   seq(cat(",a",j,"::integer[4]"),j=1..d),
                   seq(cat(",b",j,"::integer[4]"),j=1..d),
                   ")\n",
                   "k := 0;\n",
                   seq(cat("for i",j," from a",j," to b",j," do\n"),j=1..d),
                   "k := k+1;\n",
                   "vec[k] := arr[","i1",seq(cat(",i",j),j=2..d),"];\n",
                   seq("end do;\n",j=1..d),
                   "end proc;\n");
        return Compiler:-Compile(parse(src));
    end proc;

    # Flatten the part of arr selected by rng into a newly allocated vector.
    # rng defaults to the full extents reported by arrdim(arr). The vector
    # has one slot per selected entry and is filled by the compiled routine
    # from arr2vec0.
    # NOTE(review): allocla is not defined in this file -- presumably it
    # comes from the library loaded at the top; confirm.
    arr2vec := proc(arr,rng:=[arrdim(arr)])
    local vec, lo, hi, sizes, total, elt, pack;
        lo,hi,sizes,total := tensrng(rng);
        elt := arrtype(arr);
        vec := allocla[elt](total);
        pack := arr2vec0(nops(lo),elt);
        pack(arr,vec,op(lo),op(hi));
        return vec;
    end proc;

    # Generate and compile a d-dimensional vector -> array copy routine.
    # The generated procedure takes (vec, arr, a1..ad, b1..bd) and stores
    # vec[1], vec[2], ... into arr[i1,...,id] for ij from aj to bj, with the
    # last dimension as the innermost loop. Results are cached per (d,dtype)
    # through option remember.
    vec2arr0 := proc(d,dtype)
    local tname, j, src;
    option remember;
        tname := convert(dtype,'string');
        src := cat("proc(vec::Array(datatype=",tname,"),arr::Array(datatype=",tname,")",
                   seq(cat(",a",j,"::integer[4]"),j=1..d),
                   seq(cat(",b",j,"::integer[4]"),j=1..d),
                   ")\n",
                   "k := 0;\n",
                   seq(cat("for i",j," from a",j," to b",j," do\n"),j=1..d),
                   "k := k+1;\n",
                   "arr[","i1",seq(cat(",i",j),j=2..d),"] := vec[k];\n",
                   seq("end do;\n",j=1..d),
                   "end proc;\n");
        return Compiler:-Compile(parse(src));
    end proc;

    # Unflatten vec into a newly allocated array whose extents are those of
    # rng, using the compiled routine from vec2arr0.
    # NOTE(review): the array is allocated from the extents only, while the
    # compiled routine indexes it from the lower to the upper bounds of rng;
    # confirm the intended behaviour for ranges that do not start at 1.
    vec2arr := proc(vec,rng)
    local lo, hi, sizes, total, elt, out, unpack;
        lo,hi,sizes,total := tensrng(rng);
        elt := arrtype(vec);
        out := allocarr[elt](sizes);
        unpack := vec2arr0(nops(lo),elt);
        unpack(vec,out,op(lo),op(hi));
        return out;
    end proc;

    # Collect every (2l+1)-wide neighbourhood of arr into the rows of a
    # matrix. Row r corresponds to the r-th centre position in the ranges
    # l+1..m-l (m being each extent of arr); column c corresponds to the
    # c-th offset in -l..l per dimension. Entry [r,c] is arr at
    # centre+offset.
    # NOTE(review): matf is not defined in this file -- presumably it
    # allocates an N1 x N2 matrix; confirm.
    vecpatches := proc(arr,l)
    local dims, nd, core, nbhd, f1, g1, f2, g2, N1, N2, out, r, c, centre, i;
        dims := [arrdim(arr)];
        nd := nops(dims);
        core := [seq(l+1..dims[i]-l,i=1..nd)];
        nbhd := [seq(-l..l,i=1..nd)];
        f1,g1 := ordmaps(core);
        f2,g2 := ordmaps(nbhd);
        N1,N2 := tensrng(core)[4],tensrng(nbhd)[4];
        out := matf(N1,N2);
        for r to N1 do
            centre := f1(r);
            for c to N2 do
                # list addition shifts the centre by the current offset
                out[r,c] := arr[op(centre+f2(c))];
            end do;
        end do;
        return out;
    end proc;

    # NOTE(review): this looks like an unfinished helper. It is not
    # referenced anywhere else in the visible source, never uses f, and has
    # no explicit return statement -- confirm whether it can be removed.
    map2arr0 := proc(f,rng,arr)
        d := nops(rng);
        ml := [arrdim(arr)];
        # Overwrites the ml assigned on the previous line. rastrange is
        # called here without its second parameter M -- TODO confirm this
        # call can succeed.
        ml := rastrange(rng);
        # tensrng only reads its first argument, so only the first value
        # produced by arrdim(arr) is taken into account here.
        al,bl,ml,N := tensrng(arrdim(arr));
        cl,dl := endrange(rng);
    end proc;

    # Callable object that samples a function on a regular grid.
    #
    # Applying map2arr(f, rng, dims, ...) forwards to getarr:
    #   f    - called with a list of coordinates, one per entry of rng
    #   rng  - list of ranges describing the sampled box
    #   dims - either an existing Array/Matrix to fill (its extents define
    #          the grid), or a size specification handed to rastrange
    #          together with any further arguments; defaults to the
    #          exported value N
    # Each cell [i1,...,id] of the result receives f evaluated at the cell
    # centre al + (i - 0.5)*dx. Returns the filled array.
    #
    # getarr uses the declared parameter dims (and _rest for any trailing
    # arguments) instead of args[3] / args[3..nargs], so that the default
    # dims:=N actually takes effect when only f and rng are supplied.
    map2arr := module()
    option object;
    export N;
    local ModulePrint,ModuleApply;
        getarr::static := proc(f,rng,dims:=N)
        local d, arr, ml, dxl, al, bl, M, F, G, k, il, xx, j;
            d := nops(rng);
            if(type(dims,'Array') or type(dims,'Matrix')) then
                # fill the caller's container; grid size follows its extents
                arr := dims;
                ml,dxl := rastrange(rng,[arrdim(arr)]);
            else
                ml,dxl := rastrange(rng,dims,_rest);
                arr := allocarr(ml);
            end if;
            al,bl := endrange(rng);
            M := convert(ml,`*`);
            F,G := ordmaps(ml);
            for k from 1 to M do
                il := F(k);
                # coordinates of the centre of cell il
                xx := [seq(al[j]+(il[j]-.5)*dxl[j],j=1..d)];
                arr[op(il)] := f(xx);
            end do;
            return arr;
        end proc;
        ModuleApply::static := getarr;
        # default reference grid size used when dims is omitted
        N := 1000;
    end module;

    # Sample f over the two ranges in rng into a matrix. The grid size comes
    # from rastrange applied to every argument after f; the matrix is
    # allocated with the two extents swapped, and map2arr fills it through a
    # wrapper that swaps the two coordinates and reflects the second one
    # within y0..y1 before calling f.
    # NOTE(review): matf is not defined in this file -- presumably it
    # allocates a matrix of the given size; confirm.
    map2mat := proc(f,rng,dims)
    local p, x0, x1, y0, y1, cells, flipped, A;
        x0,x1 := op(rng[1]);
        y0,y1 := op(rng[2]);
        cells := rastrange(args[2..nargs])[1];
        A := matf(cells[2],cells[1]);
        flipped := p->f([p[2],y0+y1-p[1]]);
        map2arr(flipped,[y0..y1,x0..x1],A);
        return A;
    end proc;

    # Split a list of ranges into two lists of floating-point endpoints:
    # first the evalf'd left ends, then the evalf'd right ends.
    endrange := proc(rng)
    local lo, hi;
        lo := map(r->evalf(op(1,r)),rng);
        hi := map(r->evalf(op(2,r)),rng);
        return lo,hi;
    end proc;

    # Choose a cell count per dimension for rasterising the box rng.
    #
    # M is either a single number or a list with one entry per dimension.
    #  - Number: it becomes the cell count of the longest side (flag=true)
    #    or the shortest side (flag=false); every other dimension is marked
    #    with flag and resolved by a recursive call with the resulting list.
    #  - List: numeric entries are taken as given. An entry `true` is
    #    replaced by ceil(length/t1) and an entry `false` by
    #    ceil(length/t0), where t1 and t0 are the largest and smallest cell
    #    widths among the numeric entries.
    # Returns the list of cell counts and the list of cell widths.
    rastrange := proc(rng,M,flag:=true)
    local d, al, bl, cl, sig, ref, ml, tl, t0, t1, dxl, i;
        d := nops(rng);
        al,bl := endrange(rng);
        cl := bl-al;
        if type(M,'numeric') then
            # sig orders the side lengths ascending; pick the reference side
            sig := sort(cl,output=permutation);
            if flag then
                ref := sig[d];
            else
                ref := sig[1];
            end if;
            ml := [seq(flag,i=1..d)];
            ml[ref] := M;
            return procname(rng,ml);
        end if;
        ml := M;
        # cell widths implied by the explicitly sized dimensions
        tl := [];
        for i to d do
            if type(ml[i],'numeric') then
                tl := [op(tl),evalf(cl[i]/ml[i])];
            end if;
        end do;
        t0,t1 := min(tl),max(tl);
        dxl := [seq(0.0,i=1..d)];
        for i to d do
            if ml[i]=false then
                ml[i] := ceil(cl[i]/t0);
            elif ml[i]=true then
                ml[i] := ceil(cl[i]/t1);
            end if;
            dxl[i] := evalf(cl[i]/ml[i]);
        end do;
        return ml,dxl;
    end proc;

    # Build (by generating and parsing source text) a procedure that maps a
    # multi-index list ii to coordinates. Component j of the result is
    # origin + step*ii[j], with step = (bl[j]-al[j])/ml[j] and
    # origin = al[j] - step/2.
    gridmap0 := proc(al,bl,ml)
    local nd, j, src;
        nd := nops(al);
        src := cat("proc(ii)\n",
                   "option inline;\n",
                   "return [",
                   seq(cat("",
                           evalf(al[j]+.5*(al[j]-bl[j])/ml[j]),
                           plusnum(evalf((bl[j]-al[j])/ml[j])),
                           "*ii[",j,"]",
                           `if`(j<nd,",","")),j=1..nd),
                   "];\n",
                   "end proc;\n");
        return parse(src);
    end proc;

    # Build (by generating and parsing source text) a procedure that maps a
    # coordinate list xx to a multi-index list. Component j of the result is
    # ceil(a*xx[j] + b) with a = ml[j]/(bl[j]-al[j]) and b = -al[j]*a, i.e.
    # ceil((xx[j]-al[j])/(bl[j]-al[j])*ml[j]).
    #
    # A stray expression statement that evaluated that formula on the
    # global name x and discarded the result has been removed: it had no
    # effect on the generated code.
    gridmap1 := proc(al,bl,ml)
    local d, j, a, b, code;
        d := nops(al);
        code := cat("proc(xx)\n","option inline;\n","return [");
        for j from 1 to d do
            a := evalf(ml[j]/(bl[j]-al[j]));
            b := evalf(-al[j]*ml[j]/(bl[j]-al[j]));
            # plusnum supplies the sign so the text reads "a*xx[j]+b" or "a*xx[j]-b"
            code := cat(code,"ceil(",a,"*xx[",j,"]",plusnum(b),")");
            if j<d then
                code := cat(code,",");
            end if;
        end do;
        code := cat(code,"];\n","end proc;\n");
        return parse(code);
    end proc;

    # Format b as a signed term for appending to generated source text:
    # "+<b>" when b is positive, the plain string form of b (which already
    # carries its minus sign) when negative, and the empty string when b is
    # zero.
    plusnum := proc(b)
        if b>0.0 then
            return cat("+",b);
        elif b<0.0 then
            return convert(b,'string');
        elif b=0.0 then
            return "";
        end if;
    end proc;

    # Return the pair of generated grid maps for the box rng rasterised
    # according to M and flag (see rastrange): first the
    # multi-index -> coordinates map (gridmap0), then the
    # coordinates -> multi-index map (gridmap1).
    gridmaps := proc(rng,M,flag:=true)
    local lo, hi, cells, steps;
        lo,hi := endrange(args);
        cells,steps := rastrange(args);
        return gridmap0(lo,hi,cells),gridmap1(lo,hi,cells);
    end proc;

    # Construct a tensor-grid object for the box rng rasterised according to
    # the remaining arguments. A fresh module is created, its init procedure
    # is run on all arguments of this call, and the module is returned.
    tensgrid := proc(rng,M)
        md := module()
        export xx2ii,ii2xx,ii2ind,ind2ii,init,dims,d,N,rng,al,bl,map2vec,map2arr,map2mat,getcoords,`whattype`;
        local ModulePrint,ModuleApply;
            # Display form: "[m1,...,md]-dimensional tensor grid".
            ModulePrint::static := proc()
                return nprintf(cat("[%d",seq(",%d",i=1..d-1),"]-dimensional tensor "
                                   "grid"),op(dims));
            end proc;
            `whattype`::static := proc()
                return 'TensGrid';
            end proc;
            # Evaluate f at the coordinates of every grid cell, in linear-index
            # order, storing the values in V. With a single argument V is
            # created by vecf(N).
            # NOTE(review): vecf is not defined in this file -- confirm its source.
            map2vec::static := proc(f,V)
                if(nargs=1) then
                    return map2vec(f,vecf(N));
                end if;
                for k from 1 to N do
                    V[k] := f(ii2xx(ind2ii(k)));
                end do;
                return V;
            end proc;
            # Evaluate f at the coordinates of every grid cell and store the
            # value at the cell's multi-index in arr. With a single argument
            # arr is created by allocarr[float[8]](dims).
            map2arr::static := proc(f,arr)
                if(nargs=1) then
                    return map2arr(f,allocarr[float[8]](dims));
                end if;
                for k from 1 to N do
                    ii := ind2ii(k);
                    arr[op(ii)] := f(ii2xx(ii));
                end do;
                return arr;
            end proc;
            # Evaluate f at the coordinates of every grid cell and store the
            # value in A at row n-ii[2]+1, column ii[1], where n is the second
            # grid extent. Only the first two entries of dims and of each
            # multi-index are used. With a single argument A is created by
            # matf(n,m).
            # NOTE(review): matf is not defined in this file -- confirm its source.
            map2mat::static := proc(f,A)
                m,n := dims[1],dims[2];
                if(nargs=1) then
                    return map2mat(f,matf(n,m));
                end if;
                for k from 1 to N do
                    ii := ind2ii(k);
                    A[n-ii[2]+1,ii[1]] := f(ii2xx(ii));
                end do;
                return A;
            end proc;
            # With no argument, return the list of coordinate vectors for all
            # d dimensions. With a dimension number i, return the values
            # a+(j-.5)*(b-a)/m for j=1..m wrapped by vecf, where a, b and m are
            # the lower end, upper end and cell count of dimension i.
            getcoords::static := proc(i)
                if(nargs=0) then
                    return [seq(getcoords(i1),i1=1..d)];
                end if;
                a,b,m := al[i],bl[i],dims[i];
                return vecf([seq(a+(j-.5)*(b-a)/m,j=1..m)]);
            end proc;
            # Applying the grid object forwards to xx2ii, which is assigned
            # in init below.
            ModuleApply::static := xx2ii;
            # Populate the grid from the constructor arguments: cell counts
            # (dims), box end points (al, bl), dimension count (d), total cell
            # count (N) and the four generated index/coordinate maps.
            # NOTE(review): dxl is assigned here but appears in neither the
            # export nor the local list of this module; the exported name rng
            # is never assigned in the visible code -- confirm both.
            init::static := proc()
                dims,dxl := rastrange(args);
                al,bl := endrange(args);
                d := nops(dims);
                N := convert(dims,`*`);
                ind2ii,ii2ind := ordmaps(dims);
                ii2xx,xx2ii := gridmaps(args);
            end proc;
        end module;
        md:-init(args);
        return md;
    end proc;

    # Convenience front end for tensgrid.
    # If the first argument is a list, all arguments are passed to tensgrid
    # and the grid object is returned. Otherwise the first argument is taken
    # as the function f: a grid is built from the remaining arguments and
    # the array produced by its map2arr(f) is returned.
    gridmap := proc(f,rng,M)
    local grid;
        if not type(args[1],'list') then
            grid := tensgrid(args[2..nargs]);
            return grid:-map2arr(f);
        end if;
        return tensgrid(args);
    end proc;

    # Re-lay a two-dimensional m x n array as an n x m matrix: entry [r,c]
    # of the result is arr[c,n-r+1], i.e. the two indices are swapped and
    # the second one is reversed.
    # NOTE(review): matf is not defined in this file -- presumably it
    # allocates an n x m matrix; confirm.
    rowmaj := proc(arr)
    local m, n, out, r, c;
        m,n := arrdim(arr);
        out := matf(n,m);
        for c to m do
            for r to n do
                out[r,c] := arr[c,n-r+1];
            end do;
        end do;
        return out;
    end proc;

end module;
