#define _XOPEN_SOURCE 600
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <limits.h>
#include <math.h>
#include <float.h>
#include <string.h>

//#define USE_SIMD GCC
//#define SIMD_TEST

// non simd ballroom_640_480 : 556 ms
//     simd ballroom_640_480 : 230 ms
//          disk access only : 199 ms
//
//  Y : psnr : min=44.60214dB avg=44.89237dB max=47.56301dB
// Cb : psnr : min=45.04994dB avg=45.33798dB max=48.08658dB
// Cr : psnr : min=45.06935dB avg=45.40971dB max=48.11539dB

// Set by the "-v" command-line flag; when true, FrameType::Data() prints each
// PSNR value as it is recorded.
bool verbose = false;

// Accumulates PSNR statistics for one component of a frame sequence (a single
// plane, or the whole frame) and prints a summary line.
//
// PSNR is computed for 8-bit samples as 10*log10(255^2 / MSE) with
// MSE = SSD / sample_count. An SSD of 0 (identical data) is reported as the
// sentinel 1e100 instead of +infinity.
class FrameType {
  public:
  const char *tag;          // label used in verbose and summary output
  int size;                 // number of samples in one instance of this component
  double psnr_min;          // smallest per-instance PSNR recorded so far
  double psnr_avg;          // mean of the per-instance PSNRs recorded so far
  double psnr_max;          // largest per-instance PSNR recorded so far
  double psnr_sum;          // running sum of per-instance PSNRs (backs psnr_avg)
  unsigned long long ssd;   // SSD summed over every recorded instance
  int dataCount;            // number of Data() calls so far

  // `tag` is stored, not copied, so it must outlive this object.
  FrameType( const char *tag, int size ) {
    this->tag = tag;
    this->size = size;
    this->psnr_min = DBL_MAX;
    this->psnr_avg = 0;
    // -DBL_MAX rather than DBL_MIN: DBL_MIN is the smallest *positive* double,
    // so a 0 dB result (every sample off by 255) could never become the max.
    this->psnr_max = -DBL_MAX;
    this->psnr_sum = 0;
    this->ssd = 0;
    this->dataCount = 0;
  }

  // Records the SSD of one instance of this component (`size` samples) and
  // updates min/avg/max and the pooled SSD.
  void Data( unsigned long long ssd ) {
    double psnr = Psnr( ssd, (double)size );

    if( psnr < psnr_min ) psnr_min = psnr;
    if( psnr > psnr_max ) psnr_max = psnr;
    psnr_sum += psnr;
    this->ssd += ssd;
    dataCount++;
    psnr_avg = psnr_sum / dataCount;

    if( verbose ) {
      printf( "%s=%-8.7g ", tag, psnr );
    }
  }

  // Prints min/avg/max of the per-instance PSNRs plus the PSNR of the pooled
  // SSD over everything recorded. Leaves all statistics untouched, so it is
  // safe to call more than once.
  void Print() {
    double psnr_total = Psnr( ssd, (double)dataCount * (double)size );

    printf( "%6s : min=%-8.7g avg=%-8.7g max=%-8.7g total=%-8.7g\n",
      tag,
      psnr_min,
      psnr_avg,
      psnr_max,
      psnr_total
    );
  }

  private:
  // PSNR in dB of `ssd` spread over `samples` 8-bit samples; 1e100 if ssd == 0.
  static double Psnr( unsigned long long ssd, double samples ) {
    if( ssd == 0 ) return 1e100;
    return 10.0 * log10( 255.0*255.0 / ( (double)ssd / samples ) );
  }
};

// Result record filled in by CompareFrames().
struct Noise {
  unsigned long long ssd;   // sum of squared byte differences
};

// Forward declarations; the definitions follow main().

// Sum of squared differences of the first `size` bytes of frameOne and
// frameTwo, written to ret_noise->ssd.
void CompareFrames( unsigned char *frameOne, unsigned char *frameTwo, int size, struct Noise *ret_noise );

// 16-byte-aligned byte buffers. Memory from aligned_malloc() keeps its offset
// in the byte before the returned pointer and must be released with
// aligned_free(), never with free().
unsigned char *aligned_malloc( int size );
void aligned_free( unsigned char *ptr );

#ifdef USE_SIMD
// 16-byte pshufb control mask read by the SSE kernel in CompareFrames;
// allocated and filled by main() before any comparison runs.
unsigned char *shuffle_mask_1;
#endif

int
main( int argc, char **argv ) {
  FILE *fileOne, *fileTwo;

  if( argc < 4 ) {
    fprintf( stderr, "yuv_psnr one.yuv two.yuv WIDTHxHEIGHT [-v]\n" );
    return( 1 );
  }

  struct Noise ret_noise;
  #ifdef USE_SIMD
  shuffle_mask_1 = aligned_malloc( 16 );
  for( int i = 0; i < 16; ) {
    // start at lsw
    shuffle_mask_1[ i++ ] = i >> 1;
    shuffle_mask_1[ i++ ] = 128;
  }
  #endif

  #ifdef SIMD_TEST
  {
    int test_size = 1000001;
    unsigned long long ssd_correct = (unsigned long long)test_size;
    ssd_correct *= 255 * 255;
    unsigned char *test_white = aligned_malloc( test_size );
    unsigned char *test_black = aligned_malloc( test_size );
    for( int i = 0; i < test_size; i++ ) {
      test_white[ i ] = 255;
      test_black[ i ] = 0;
    }
    CompareFrames( test_white, test_black, test_size, &ret_noise );
    printf( "ret_noise.ssd = %llu (correct = %llu) %s\n", 
      ret_noise.ssd,
      ssd_correct,
      ( ret_noise.ssd == ssd_correct ) ? "passed" : "FAILED"
      );
    ssd_correct = 0;
    int x = 0, d = 1;
    for( int i = 0; i < test_size; i++ ) {
      test_white[ i ] = 255 - x;
      test_black[ i ] = x;
      int diff = 255 - x*2;
      ssd_correct += diff * diff;
      x += d;
      if( x == 0 ) d = 1;
      else if( x == 255 ) d = -1;
    }
    CompareFrames( test_white, test_black, test_size, &ret_noise );
    printf( "ret_noise.ssd = %llu (correct = %llu) %s\n", 
      ret_noise.ssd,
      ssd_correct,
      ( ret_noise.ssd == ssd_correct ) ? "passed" : "FAILED"
      );
    aligned_free( test_white );
    aligned_free( test_black );
    return( 0 );
  }
  #endif

  /*
  #ifdef USE_SIMD
  #if USE_SIMD==GCC
  printf( "USE_SIMD=GCC\n" );
  #endif
  #endif
  */

  if( (fileOne = fopen( argv[ 1 ], "rb" )) == NULL ) {
    fprintf( stderr, "Could not open '%s'\n", argv[ 1 ] );
    return( 1 );
  }
  
  if( (fileTwo = fopen( argv[ 2 ], "rb" )) == NULL ) {
    fprintf( stderr, "Could not open '%s'\n", argv[ 2 ] );
    return( 1 );
  }
  
  int width, height;
  sscanf( argv[3], "%dx%d", &width, &height );

  if( argc >= 5 ) {
    if( strcmp( argv[ 4 ], "-v" ) == 0 ) verbose = true;
  }

  if( verbose ) {
    printf( "width  = %d\n", width );
    printf( "height = %d\n", height );
  }

  // psnr = 10 * log10( 255^2 / mse^2 )
  // mse = ssd/N

  // total mse = (ssd1 + ssd2 + ssd3 + ... )/((Ny + 2*Nc)*f)

  int size_luma = width * height;
  int size_chroma = size_luma >> 2;
  int size_frame = size_luma + 2*size_chroma;
  // for some strange reason, reading in more data into a larger buffer
  // slows down the program slightly.
  int frameBufferSize = size_frame * 1;
  unsigned char *frameOne = aligned_malloc( frameBufferSize );
  unsigned char *frameTwo = aligned_malloc( frameBufferSize );

  FrameType *frameType[ 4 ];
  frameType[ 0 ] = new FrameType( "Y", size_luma );
  frameType[ 1 ] = new FrameType( "Cb", size_chroma );
  frameType[ 2 ] = new FrameType( "Cr", size_chroma );
  frameType[ 3 ] = new FrameType( "Frame", size_luma + 2*size_chroma );

  bool done = false;
  int frameIndex;
  for( frameIndex = 0; !done; ) {
    int bytesRead, bytesReadTwo;
    bytesRead = fread( frameOne, 1, frameBufferSize, fileOne );
    bytesReadTwo = fread( frameTwo, 1, frameBufferSize, fileTwo );
    if( bytesReadTwo < bytesRead ) bytesRead = bytesReadTwo;
    // has at least one of the files reached it's end?
    if( bytesRead < frameBufferSize ) done = true;

    // foreach complete frame read into the buffer
    for( int offset = 0; offset <= bytesRead - size_frame; frameIndex++ ) {
      unsigned long long ssdFrame = 0;

      // foreach component of the frame
      if( verbose ) printf( "[%4d] : ", frameIndex );

      for( int i = 0; i < 3; i++ ) {
        int size = frameType[ i ]->size;
        CompareFrames( frameOne + offset, frameTwo + offset, size, &ret_noise );
        offset += size;

        // keep track of some statistics on psnr
        ssdFrame += ret_noise.ssd;
        frameType[ i ]->Data( ret_noise.ssd );
      } // foreach frame component

      frameType[ 3 ]->Data( ssdFrame );
      if( verbose ) printf( "\n" );

    } // foreach complete frame in buffer
    if( done ) break; // don't increment frameIndex
  }
  if( frameIndex == 0 ) return( 2 );

  for( int i = 0; i < 4; i++ ) {
    frameType[ i ]->Print();
  }
  printf( "%6s : %d\n", "frames", frameIndex );

  aligned_free( frameOne );
  aligned_free( frameTwo );
  fclose( fileOne );
  fclose( fileTwo );
}

// Allocates `size` bytes starting on a 16-byte boundary. This alignment is
// required by the movdqa load of shuffle_mask_1 in the SSE path of
// CompareFrames. The distance back to the block returned by malloc (1..16)
// is stored in the byte just before the returned pointer, so that
// aligned_free() can recover it.
// Returns NULL if `size` is negative or the allocation fails.
unsigned char *aligned_malloc( int size ) {
  if( size < 0 ) return( NULL );
  // size_t arithmetic: `size + 16` in int would overflow near INT_MAX.
  unsigned char *raw = (unsigned char *)malloc( (size_t)size + 16 );
  if( raw == NULL ) return( NULL );
  // uintptr_t, not unsigned long: the latter is 32-bit on LLP64 platforms.
  // The shift is always 1..16 (never 0) so there is room for the offset byte.
  unsigned int shift = 16u - (unsigned int)( (uintptr_t)raw & 0xF );
  unsigned char *ret_ptr = raw + shift;
  ret_ptr[ -1 ] = (unsigned char)shift;
  return( ret_ptr );
}

// Releases a pointer obtained from aligned_malloc(). Passing NULL is a no-op,
// mirroring free().
void aligned_free( unsigned char *ptr ) {
  if( ptr == NULL ) return;
  // The byte before the aligned pointer holds the offset back to the block
  // that malloc originally returned.
  free( ptr - ptr[ -1 ] );
}

// Computes the sum of squared differences (SSD) between the first `size`
// bytes of frameOne and frameTwo and stores it in ret_noise->ssd.
//
// When USE_SIMD is defined as GCC, whole 16-byte blocks go through an SSSE3
// inline-assembly kernel and the size % 16 trailing bytes are summed in C.
// The kernel is written for 32-bit x86 (it addresses the buffers through
// esi/edi). In every other configuration a plain scalar loop is used.
void
CompareFrames( unsigned char *frameOne, unsigned char *frameTwo, int size, struct Noise *ret_noise ) {
  unsigned long long ssd = 0;
  #if defined( USE_SIMD ) && USE_SIMD==GCC
  int blockCount = size > 0 ? size >> 4 : 0;
  // worst case: 255 diff at every pixel: (2^32-1)/255/255/16 = 4,128.188
  // Confirmed by testing. 4129 overflows, but 4128 doesn't. That means there
  // is no sign extension when movd stores into xmm7. Of course! Why would it?
  // Awesome!
  // 1 million blocks per y frame would be more than enough, that's 242*2^32 or
  // 363*2^32 per yuv420 frame. That leaves quite a few frames, even for the 
  // worst case
  //
  // The kernel counts blocks with `loop`, which decrements ecx before testing
  // it: entering with ecx == 0 would iterate 2^32 times far past both
  // buffers, so it only runs when there is at least one full block.
  if( blockCount > 0 ) {
    asm (
      "\n    movl     %3, %%ebx"
      "\n    movdqa   (%%ebx), %%xmm4"       // shuffle_mask_1 in xmm4
      "\n    movl     %1, %%esi"
      "\n    movl     %2, %%edi"
      "\n    movl     %4, %%ebx"
      "\n    pxor     %%xmm5, %%xmm5"     // 64-bit final sum
      "\n    SSE_LOOP_START_%=:"
      "\n    pxor     %%xmm7, %%xmm7"     // clear intermediate sum
      "\n    movl     %%ebx, %%ecx"       // min(4128,ebx)->ecx
      "\n    cmp      $4128, %%ecx"
      "\n    jle      DONT_TRUNCATE_%="
      "\n    movl     $4128, %%ecx"       // limit to 4128 blocks
      "\n    DONT_TRUNCATE_%=:"
      "\n    subl     %%ecx, %%ebx"       // blocks remaining after next loop (ebx)
      "\n    SSE_LOOP_%=:"
      "\n    movq     (%%esi), %%xmm0"
      "\n    movq     (%%edi), %%xmm2"
      "\n"
      "\n    pshufb   %%xmm4, %%xmm0"
      "\n    pshufb   %%xmm4, %%xmm2"
      "\n"
      "\n    psubw    %%xmm2, %%xmm0"
      "\n    pmaddwd  %%xmm0, %%xmm0"
      "\n    paddd    %%xmm0, %%xmm7"
      "\n"
      "\n    movq     8(%%esi), %%xmm0"
      "\n    movq     8(%%edi), %%xmm2"
      "\n"
      "\n    pshufb   %%xmm4, %%xmm0"
      "\n    pshufb   %%xmm4, %%xmm2"
      "\n"
      "\n    psubw    %%xmm2, %%xmm0"
      "\n    pmaddwd  %%xmm0, %%xmm0"
      "\n    paddd    %%xmm0, %%xmm7"
      "\n"
      "\n    addl     $16, %%esi"
      "\n    addl     $16, %%edi"
      "\n"
      "\n    loop     SSE_LOOP_%="
      "\n"
      "\n    phaddd   %%xmm7, %%xmm7"     // horizontally add all dwords
      "\n    phaddd   %%xmm7, %%xmm7"
      "\n    movd     %%xmm7, %%ecx"
      "\n    movd     %%ecx,  %%xmm7"     // zero upper 96 bits of xmm7
      "\n    paddq    %%xmm7, %%xmm5"     // add to final 64 bit sum: xmm5
      "\n    cmpl     $0, %%ebx"
      "\n    jg       SSE_LOOP_START_%="  // process next 4128 chunk of blocks
      "\n"
      "\n    movq     %%xmm5, %0"         // final sum in xmm5 qword
      "\n"
      : "=m" (ssd)     // output
      : "m" (frameOne), "m" (frameTwo), "m" (shuffle_mask_1), "m" (blockCount)
      // Clobbers: the general registers used for addressing and counting,
      // every xmm register the kernel writes, and the flags. "memory" is
      // needed because the buffers are read through registers the compiler
      // cannot see.
      : "%ebx", "%ecx", "%esi", "%edi",
        "%xmm0", "%xmm2", "%xmm4", "%xmm5", "%xmm7", "cc", "memory"
    );
  }
  // compute any trailing bytes (if size is not divisible by 16)
  for( int i = blockCount * 16; i < size; i++ ) {
    int diff = (int)frameOne[ i ] - (int)frameTwo[ i ];
    ssd += diff*diff;
  }
  #else
  for( int i = 0; i < size; i++ ) {
    int diff = (int)frameOne[ i ] - (int)frameTwo[ i ];
    ssd += diff*diff;
  }
  #endif
  ret_noise->ssd = ssd;
}

