import tNxN, numpy, dct

def fdct8x8(b):
    '''
    Forward 8x8 DCT.

    Hands the block `b` to tNxN.fdctNxN and returns whatever that
    routine produces.
    '''
    coeffs = tNxN.fdctNxN(b)
    return coeffs
    
def idct8x8(b):
    '''
    Inverse 8x8 DCT.

    Hands the block `b` to tNxN.idctNxN and returns whatever that
    routine produces.
    '''
    samples = tNxN.idctNxN(b)
    return samples

def fdct8x8x8(b):
    '''
    Forward 8x8x8 DCT.

    First applies fdct8x8 to each of the eight slices b[i,:,:], then
    applies dct.fdct along the first axis for every (row, column)
    position of the intermediate result.

    b -- array indexed as b[i,:,:] for i in 0..7; each slice is passed
         to fdct8x8.

    Returns a freshly allocated (8, 8, 8) float array holding the result.
    '''
    # Builtin float names the same dtype as the old numpy.float alias,
    # which no longer exists in recent NumPy releases.
    d = numpy.empty((8, 8, 8), float)
    # range (not xrange) so the loops run on both Python 2 and Python 3.
    for i in range(8):
        d[i,:,:] = fdct8x8(b[i,:,:])
    # Second pass reads and overwrites d in place, one axis-0 line at a time.
    for r in range(8):
        for c in range(8):
            d[:,r,c] = dct.fdct(d[:,r,c])
    return d
    
def idct8x8x8(b):
    '''
    Inverse 8x8x8 DCT.

    First applies dct.idct along the first axis of `b` for every
    (row, column) position, then applies idct8x8 to each of the eight
    slices d[i,:,:] of the intermediate result.

    b -- array indexed as b[:,r,c] for r, c in 0..7.

    Returns a freshly allocated (8, 8, 8) float array holding the result.
    '''
    # Builtin float names the same dtype as the old numpy.float alias,
    # which no longer exists in recent NumPy releases.
    d = numpy.empty((8, 8, 8), float)
    # range (not xrange) so the loops run on both Python 2 and Python 3.
    for r in range(8):
        for c in range(8):
            d[:,r,c] = dct.idct(b[:,r,c])
    # Second pass reads and overwrites d in place, one slice at a time.
    for i in range(8):
        d[i,:,:] = idct8x8(d[i,:,:])
    return d

def do_f8x8(s):
    '''
    Apply fdct8x8 to every 8x8 tile of the 2-D array `s`, in place.

    Tiles start at row and column offsets 0, 8, 16, ... and each tile is
    overwritten with the value fdct8x8 returns for it. Because the result
    is stored back into `s`, it is cast to the dtype of `s`.

    Returns `s` itself (the same object that was passed in).
    '''
    # range (not xrange) so the loops run on both Python 2 and Python 3.
    for r in range(0, s.shape[0], 8):
        for c in range(0, s.shape[1], 8):
            s[r:r+8,c:c+8] = fdct8x8(s[r:r+8,c:c+8])
    return s

def _get_zz():
    '''
    Return a list of flat (row-major, i*8 + j) indices into an 8x8 block,
    in the order dct.zigzag yields its (i, j) pairs.
    '''
    template = numpy.empty((8,8))
    return [row*8 + col for row, col in dct.zigzag(template)]

def build_vector_from_8x8(s):
    '''
    Gather every 8x8 tile of the 2-D array `s` into the columns of a matrix.

    Tiles are visited row by row (offsets 0, 8, 16, ... in each dimension).
    Column i of the result holds the 64 values of the i-th tile, reordered
    by the flat indices from _get_zz().

    The column count is computed with floor division while the loops visit
    every offset, so both dimensions of `s` are expected to be multiples
    of 8.

    Returns a new float array of shape (64, number_of_tiles).
    '''
    # Builtin float names the same dtype as the old numpy.float alias,
    # which no longer exists in recent NumPy releases.
    v = numpy.zeros((64, (s.shape[0]//8)*(s.shape[1]//8)), float)
    i = 0
    zz = _get_zz()
    # range (not xrange) so the loops run on both Python 2 and Python 3.
    for r in range(0, s.shape[0], 8):
        for c in range(0, s.shape[1], 8):
            v[:,i] = s[r:r+8,c:c+8].flat[zz]
            i += 1
    return v

def do_i8x8(s):
    '''
    Apply idct8x8 to every 8x8 tile of the 2-D array `s`, in place.

    Tiles start at row and column offsets 0, 8, 16, ... and each tile is
    overwritten with the value idct8x8 returns for it. Because the result
    is stored back into `s`, it is cast to the dtype of `s`.

    Returns `s` itself (the same object that was passed in).
    '''
    # range (not xrange) so the loops run on both Python 2 and Python 3.
    for r in range(0, s.shape[0], 8):
        for c in range(0, s.shape[1], 8):
            s[r:r+8,c:c+8] = idct8x8(s[r:r+8,c:c+8])
    return s

def set_8x8_from_vector(s, v):
    '''
    Scatter the columns of `v` back into the 8x8 tiles of the 2-D array `s`.

    Tiles are visited row by row (offsets 0, 8, 16, ... in each dimension).
    For the i-th tile, v[k,i] is written to the flat tile position given by
    the k-th index from _get_zz().

    s -- array modified in place.
    v -- array indexed as v[:,i], one column per tile.

    Returns `s` itself (the same object that was passed in).
    '''
    i = 0
    zz = _get_zz()
    # range (not xrange) so the loops run on both Python 2 and Python 3.
    for r in range(0, s.shape[0], 8):
        for c in range(0, s.shape[1], 8):
            s[r:r+8,c:c+8].flat[zz] = v[:,i]
            i += 1
    return s

