#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>

#include <sp/spLib.h>
#include <sp/spMain.h>

/* Direction flag used by analysisLoop3() and analysisLoopBatched() as the
 * "inverse" argument of the real FFT calls: 0 benchmarks the forward
 * transform, nonzero benchmarks the inverse transform. */
#define FFTTEST_INV_FLAG 0

/* Built-in default for the -plugin option (see option[] below).
 * Change which branch is enabled to select another default plugin.
 * If no branch defines FFT_PLUGIN_NAME, the option default is NULL and
 * plugin_name stays empty unless a name is given on the command line,
 * in which case spMain() does not load any plugin. */
#if 1
#define FFT_PLUGIN_NAME "oourafft"
#elif 0
#define FFT_PLUGIN_NAME "fftw"
#elif 0
#define FFT_PLUGIN_NAME "cufft"
#else
#undef FFT_PLUGIN_NAME
#endif

/*
 * Benchmark the reference complex FFT routine spfft() over a test signal.
 *
 * The signal in vec is cut into frames of framel samples that advance by
 * shiftl samples. The first frame starts at sample -framel/2. Samples
 * outside the signal, and bins at index >= framel, are set to zero, so each
 * frame is zero-padded to fftl points before the in-place forward transform.
 *
 * vec         input signal
 * shiftl      frame shift in samples
 * framel      frame length in samples
 * fftl        FFT length in points
 * max_nframe  maximum number of frames to process; <= 0 means no limit
 * difftime    if non-NULL, receives the time spent in the frame loop in
 *             seconds, as measured with clock()
 *
 * Returns the number of frames processed.
 */
static long analysisLoop(spDVector vec, long shiftl, long framel, long fftl, long max_nframe, double *difftime)
{
    long k;
    long pos;
    long nframe;
    double *real, *imag;
    clock_t t0, t1;   /* clock() returns clock_t, not time_t */

    real = xspAlloc(fftl, double);
    imag = xspAlloc(fftl, double);

    t0 = clock();
    spDebug(10, "analysisLoop", "t0 = %f\n", (double)t0);
    
    pos = -framel/2;
    nframe = 0;

    while (pos < vec->length && (max_nframe <= 0 || nframe < max_nframe)) {
        /* extract the current frame, zero-padding outside the signal */
        for (k = 0; k < fftl; k++) {
            if (k < framel && pos + k < vec->length && pos + k >= 0) {
                real[k] = vec->data[pos + k];
            } else {
                real[k] = 0.0;
            }
            imag[k] = 0.0;
        }

        spfft(real, imag, fftl, 0);
        
        pos += shiftl;
        nframe++;
    }

    t1 = clock();
    spDebug(10, "analysisLoop", "t1 = %f, CLOCKS_PER_SEC = %f\n", (double)t1, (double)CLOCKS_PER_SEC);

    if (difftime != NULL) {
        *difftime = (double)(t1 - t0) / (double)CLOCKS_PER_SEC;
    }
    
    xspFree(real);
    xspFree(imag);

    return nframe;
}

/*
 * Benchmark the plugin's complex FFT (spExecFFT) and optionally measure its
 * accuracy against the reference complex FFT spfft().
 *
 * Framing is the same as in analysisLoop(): frames of framel samples start
 * at -framel/2, advance by shiftl, and are zero-padded to fftl points.
 *
 * plugin, precision  passed to spInitFFTByPlugin() together with the FFT
 *                    order spNextPow2(fftl)
 * vec, shiftl, framel, fftl, max_nframe  as in analysisLoop()
 * odist     if non-NULL, each frame is also transformed with spfft(), and
 *           *odist receives the RMS difference per bin,
 *           sqrt(sum |X_plugin - X_ref|^2 / (fftl * nframe)), or 0.0 when
 *           no frame was processed. The reference computation runs inside
 *           the timed region, so pass odist == NULL when timing.
 * difftime  if non-NULL, receives the time spent in the frame loop in
 *           seconds, as measured with clock()
 *
 * Returns the number of frames processed, or -1 if the FFT could not be
 * initialized. On failure, *odist and *difftime (when given) are set to 0.0.
 */
static long analysisLoop2(spPlugin *plugin, spFFTPrecision precision, spDVector vec, long shiftl, long framel,
                          long fftl, long max_nframe, double *odist, double *difftime)
{
    long k;
    long pos;
    long nframe;
    long order;
    double *real, *imag;
    double *real2 = NULL, *imag2 = NULL;   /* reference buffers, only used when odist != NULL */
    double rx, ix;
    double dist;
    clock_t t0, t1;   /* clock() returns clock_t, not time_t */
    spFFTRec fftrec;

    order = spNextPow2(fftl);
    
    if ((fftrec = spInitFFTByPlugin(plugin, order, precision)) == NULL) {
        if (difftime != NULL) *difftime = 0.0;
        if (odist != NULL) *odist = 0.0;
        return -1;
    }
        
    real = xspAlloc(fftl, double);
    imag = xspAlloc(fftl, double);
    if (odist != NULL) {
        real2 = xspAlloc(fftl, double);
        imag2 = xspAlloc(fftl, double);
    }

    t0 = clock();
    spDebug(10, "analysisLoop2", "t0 = %f\n", (double)t0);
    
    pos = -framel/2;
    nframe = 0;
    dist = 0.0;

    while (pos < vec->length && (max_nframe <= 0 || nframe < max_nframe)) {
        for (k = 0; k < fftl; k++) {
            if (k < framel && pos + k < vec->length && pos + k >= 0) {
                real[k] = vec->data[pos + k];
            } else {
                real[k] = 0.0;
            }
            imag[k] = 0.0;
        }
        spExecFFT(fftrec, real, imag, 0);

        if (odist != NULL) {
            /* transform the same frame with the reference FFT */
            for (k = 0; k < fftl; k++) {
                if (k < framel && pos + k < vec->length && pos + k >= 0) {
                    real2[k] = vec->data[pos + k];
                } else {
                    real2[k] = 0.0;
                }
                imag2[k] = 0.0;
            }
            spfft(real2, imag2, fftl, 0);
            
            /* accumulate the squared complex error over all bins */
            for (k = 0; k < fftl; k++) {
                rx = real[k] - real2[k];
                ix = imag[k] - imag2[k];
                dist += (SQUARE(rx) + SQUARE(ix));
            }
        }
        
        spDebug(50, "analysisLoop2", "done: nframe = %ld \n", nframe);
        pos += shiftl;
        nframe++;
    }

    t1 = clock();
    spDebug(10, "analysisLoop2", "t1 = %f, CLOCKS_PER_SEC = %f\n", (double)t1, (double)CLOCKS_PER_SEC);

    if (odist != NULL) {
        spDebug(-10, "analysisLoop2", "total dist = %f\n", dist);
        /* RMS error per bin; the product is formed in double to avoid long
         * overflow, and the guard avoids 0/0 when no frame was processed */
        dist = (nframe > 0 ? sqrt(dist / ((double)fftl * (double)nframe)) : 0.0);
        *odist = dist;
    }
    if (difftime != NULL) {
        *difftime = (double)(t1 - t0) / (double)CLOCKS_PER_SEC;
    }
    
    xspFree(real);
    xspFree(imag);
    if (odist != NULL) {
        xspFree(real2);
        xspFree(imag2);
    }

    spFreeFFT(fftrec);

    return nframe;
}

/*
 * Benchmark the plugin's real FFT (spExecRealFFT) and optionally measure its
 * accuracy against the reference complex FFT spfft().
 *
 * Framing is the same as in analysisLoop(). FFTTEST_INV_FLAG selects the
 * direction:
 *  - forward (0): each zero-padded frame is transformed in place by
 *    spExecRealFFT().
 *  - inverse (nonzero): each frame is first transformed with the plugin's
 *    complex FFT, packed into the layout below, and then inverse-transformed
 *    by spExecRealFFT(). This preparation runs inside the timed region.
 *
 * Packed real-spectrum layout used by this function (fftl2 = fftl/2):
 *   real[0]    = Re X[0]
 *   real[1]    = Re X[fftl2]
 *   real[2k]   = Re X[k],  real[2k+1] = Im X[k]   for 1 <= k < fftl2
 *
 * plugin, precision  passed to spInitFFTByPlugin() together with the FFT
 *                    order spNextPow2(fftl)
 * vec, shiftl, framel, fftl, max_nframe  as in analysisLoop()
 * odist     if non-NULL, *odist receives sqrt(total squared error / nframe)
 *           against spfft(), or 0.0 when no frame was processed
 * difftime  if non-NULL, receives the time spent in the frame loop in
 *           seconds, as measured with clock()
 *
 * Returns the number of frames processed, or -1 if the FFT could not be
 * initialized. On failure, *odist and *difftime (when given) are set to 0.0.
 */
static long analysisLoop3(spPlugin *plugin, spFFTPrecision precision, spDVector vec, long shiftl, long framel,
                          long fftl, long max_nframe, double *odist, double *difftime)
{
    int inv;
    long k;
    long pos;
    long nframe;
    long order;
    long fftl2;
    double *real;
    double *real2, *imag2;
    double rx, ix;
    double dist;
    clock_t t0, t1;   /* clock() returns clock_t, not time_t */
    spFFTRec fftrec;

    inv = FFTTEST_INV_FLAG;
    
    order = spNextPow2(fftl);
    
    if ((fftrec = spInitFFTByPlugin(plugin, order, precision)) == NULL) {
        if (difftime != NULL) *difftime = 0.0;
        if (odist != NULL) *odist = 0.0;
        return -1;
    }
    fftl2 = fftl / 2;
        
    real = xspAlloc(fftl, double);
    real2 = xspAlloc(fftl, double);
    imag2 = xspAlloc(fftl, double);

    t0 = clock();
    spDebug(10, "analysisLoop3", "t0 = %f\n", (double)t0);
    
    pos = -framel/2;
    nframe = 0;
    dist = 0.0;

    while (pos < vec->length && (max_nframe <= 0 || nframe < max_nframe)) {
        spDebug(100, "analysisLoop3", "nframe = %ld\n", nframe);
        if (inv) {
            /* build a spectrum to invert: complex FFT of the frame ... */
            for (k = 0; k < fftl; k++) {
                if (k < framel && pos + k < vec->length && pos + k >= 0) {
                    real2[k] = vec->data[pos + k];
                } else {
                    real2[k] = 0.0;
                }
                imag2[k] = 0.0;
            }
            spExecFFT(fftrec, real2, imag2, 0);

            /* ... packed into the real-FFT layout */
            real[0] = real2[0];
            real[1] = real2[fftl2];
            for (k = 1; k < fftl2; k++) {
                real[2 * k] = real2[k];
                real[2 * k + 1] = imag2[k];
            }
        } else {
            for (k = 0; k < fftl; k++) {
                if (k < framel && pos + k < vec->length && pos + k >= 0) {
                    real[k] = vec->data[pos + k];
                } else {
                    real[k] = 0.0;
                }
            }
        }
        spExecRealFFT(fftrec, real, inv);

        if (odist != NULL) {
            if (!inv) {
                for (k = 0; k < fftl; k++) {
                    if (k < framel && pos + k < vec->length && pos + k >= 0) {
                        real2[k] = vec->data[pos + k];
                    } else {
                        real2[k] = 0.0;
                    }
                    imag2[k] = 0.0;
                }
            }
            /* in inverse mode real2/imag2 still hold the complex spectrum,
             * so this reproduces the time-domain frame */
#if 1
            spfft(real2, imag2, fftl, inv);
#else
            spExecFFT(fftrec, real2, imag2, inv);
#endif

            if (inv) {
                for (k = 0; k < fftl; k++) {
                    rx = real2[k] - real[k];
                    dist += SQUARE(rx);
                }
            } else {
                /* DC and Nyquist bins: the packed output has no imaginary part */
                rx = real[0] - real2[0]; ix = imag2[0]; 
                dist += (SQUARE(rx) + SQUARE(ix));
                rx = real[1] - real2[fftl2]; ix = imag2[fftl2]; 
                dist += (SQUARE(rx) + SQUARE(ix));
                for (k = 1; k < fftl2; k++) {
                    rx = real[2*k] - real2[k];
                    ix = real[2*k+1] - imag2[k];
                    dist += (SQUARE(rx) + SQUARE(ix));
                }
            }
        }
        
        pos += shiftl;
        nframe++;
    }

    t1 = clock();
    spDebug(10, "analysisLoop3", "t1 = %f, CLOCKS_PER_SEC = %f\n", (double)t1, (double)CLOCKS_PER_SEC);

    if (odist != NULL) {
        spDebug(-10, "analysisLoop3", "total dist = %f\n", dist);
        /* guard avoids 0/0 when no frame was processed */
        dist = (nframe > 0 ? sqrt(dist / (double)nframe) : 0.0);
        *odist = dist;
    }
    if (difftime != NULL) {
        *difftime = (double)(t1 - t0) / (double)CLOCKS_PER_SEC;
    }
    
    xspFree(real);
    xspFree(real2);
    xspFree(imag2);

    spFreeFFT(fftrec);

    return nframe;
}

/*
 * Benchmark the plugin's batched real FFT: frames are collected `batch` at a
 * time into one contiguous buffer (frame m at data[m * fftl]) and transformed
 * with a single spExecRealFFT() call, or with spExecFFTPower() when
 * calc_power is SP_TRUE.
 *
 * Framing is the same as in analysisLoop(). If the number of frames is not a
 * multiple of batch, the last partial batch is still transformed: its unused
 * slots are zero-filled, and only the real frames are compared in the
 * accuracy check.
 *
 * plugin, precision  passed to spInitBatchedFFTByPlugin() together with
 *                    spNextPow2(fftl) and batch
 * batch      number of frames per batched call; must be >= 1
 * odist      if non-NULL, each frame is also transformed with sprfft()
 *            (followed by rffttopower() when calc_power is SP_TRUE), and
 *            *odist receives sqrt(total squared error / nframe), or 0.0
 *            when no frame was processed
 * difftime   if non-NULL, receives the time spent in the frame loop in
 *            seconds, as measured with clock()
 * calc_power SP_TRUE to benchmark spExecFFTPower() instead of spExecRealFFT()
 *
 * Returns the number of frames processed, or -1 if batch <= 0 or the FFT
 * could not be initialized. On failure, *odist and *difftime (when given)
 * are set to 0.0.
 */
static long analysisLoopBatched(spPlugin *plugin, spFFTPrecision precision, spDVector vec, long shiftl, long framel,
                                long fftl, long batch, long max_nframe, double *odist, double *difftime, spBool calc_power)
{
    int inv;
    int is_last;
    long k, m;
    long pos, old_pos;
    long nframe;
    long order;
    long batch_index;
    double *data;
    double *data2;
    double rx;
    double dist;
    clock_t t0, t1;   /* clock() returns clock_t, not time_t */
    spFFTRec fftrec;

    inv = FFTTEST_INV_FLAG;

    /* With batch <= 0 a batch would never be "full", and writes at
     * data[batch_index * fftl] would run past the end of the buffer. */
    if (batch <= 0) {
        if (difftime != NULL) *difftime = 0.0;
        if (odist != NULL) *odist = 0.0;
        return -1;
    }
    
    order = spNextPow2(fftl);
    
    if ((fftrec = spInitBatchedFFTByPlugin(plugin, order, batch, precision)) == NULL) {
        if (difftime != NULL) *difftime = 0.0;
        if (odist != NULL) *odist = 0.0;
        return -1;
    }
        
    data = xspAlloc(fftl * batch, double);
    /* reference buffer: frames are checked one at a time */
    data2 = xspAlloc(fftl, double);

    t0 = clock();
    spDebug(10, "analysisLoopBatched", "t0 = %f\n", (double)t0);
    
    pos = -framel/2;
    old_pos = pos;   /* start of the next frame to verify against the reference */
    nframe = 0;
    batch_index = 0;
    dist = 0.0;

    while (pos < vec->length && (max_nframe <= 0 || nframe < max_nframe)) {
        spDebug(100, "analysisLoopBatched", "nframe = %ld\n", nframe);
        for (k = 0; k < fftl; k++) {
            if (k < framel && pos + k < vec->length && pos + k >= 0) {
                data[batch_index * fftl + k] = vec->data[pos + k];
            } else {
                data[batch_index * fftl + k] = 0.0;
            }
        }
        ++batch_index;

        pos += shiftl;
        nframe++;

        /* same condition as the while loop: is this the final frame? */
        is_last = !(pos < vec->length && (max_nframe <= 0 || nframe < max_nframe));
        
        if (batch_index == batch || is_last) {
            /* a partial final batch still runs a full batched transform;
             * give the unused slots defined (zero) contents */
            for (k = batch_index * fftl; k < batch * fftl; k++) {
                data[k] = 0.0;
            }

            if (calc_power == SP_TRUE) {
                spExecFFTPower(fftrec, data, 1.0);
            } else {
                spExecRealFFT(fftrec, data, inv);
            }

            if (odist != NULL) {
                /* only the batch_index frames actually filled are compared */
                for (m = 0; m < batch_index; m++) {
                    for (k = 0; k < fftl; k++) {
                        if (k < framel && old_pos + k < vec->length && old_pos + k >= 0) {
                            data2[k] = vec->data[old_pos + k];
                        } else {
                            data2[k] = 0.0;
                        }
                    }
                    sprfft(data2, fftl, inv);
                    if (calc_power == SP_TRUE) {
                        rffttopower(data2, 1.0, fftl);
                    }
            
                    for (k = 0; k < fftl; k++) {
                        rx = data[m * fftl + k] - data2[k];
                        dist += SQUARE(rx);
                    }
                    
                    old_pos += shiftl;
                }
            }
            
            batch_index = 0;
        }
    }

    t1 = clock();
    spDebug(10, "analysisLoopBatched", "t1 = %f, CLOCKS_PER_SEC = %f\n", (double)t1, (double)CLOCKS_PER_SEC);

    if (odist != NULL) {
        spDebug(-10, "analysisLoopBatched", "total dist = %f\n", dist);
        /* guard avoids 0/0 when no frame was processed */
        dist = (nframe > 0 ? sqrt(dist / (double)nframe) : 0.0);
        *odist = dist;
    }
    if (difftime != NULL) {
        *difftime = (double)(t1 - t0) / (double)CLOCKS_PER_SEC;
    }
    
    xspFree(data);
    xspFree(data2);

    spFreeFFT(fftrec);

    return nframe;
}

/* Command-line settings, filled in by spGetOptionsValue() in spMain(). */
static char plugin_name[SP_MAX_PATHNAME] = "";  /* FFT plugin to load; may be overridden by the positional argument */
static spBool help_flag;
static int debug_level = -1;  /* passed to spSetDebugLevel() */
static long batch;            /* frames per batched FFT call; also scales the number of frames (batch * 128) */
static long min_order;        /* smallest FFT order to benchmark (FFT length = POW2(order)) */
static long max_order;        /* largest FFT order to benchmark */

static spOptions options;
/* Option table for spGetOptions(). The last field of each entry appears to be
 * the default value given as a string (e.g. "8" for -min); NULL means no
 * default. */
static spOption option[] = {
    {"-p", "-plugin", "FFT plugin name", "fft_plugin", 
         SP_TYPE_STRING_A, plugin_name,
#ifdef FFT_PLUGIN_NAME
         FFT_PLUGIN_NAME
#else
         NULL
#endif
    },
    {"-min", NULL, "minimum FFT order", "min_order", 
         SP_TYPE_LONG, &min_order, "8"},
    {"-max", NULL, "maximum FFT order", "max_order", 
         SP_TYPE_LONG, &max_order, "16"},
    {"-batch", NULL, "batch size", "batch", 
         SP_TYPE_LONG, &batch, "4"},
    {"-debug", NULL, "debug level", NULL,
         SP_TYPE_INT, &debug_level, NULL},
    {"-h", "-help", "display this message", NULL,
         SP_TYPE_BOOLEAN, &help_flag, SP_FALSE_STRING},
};

/* Label of the optional positional argument; spMain() copies it into
 * plugin_name when present. */
char *file_label[] = {
    "[FFT plugin]",
};

/* Convert a total loop time (seconds) into microseconds per frame.
 * Returns 0.0 when nframe <= 0, e.g. when an analysis loop failed and
 * returned -1. */
static double benchFrameTimeUs(double difftime, long nframe)
{
    return (nframe > 0 ? 1000000.0 * difftime / (double)nframe : 0.0);
}

/* Speed-up of a method relative to the reference: ref_us / time_us.
 * Returns 0.0 when time_us is zero, for example when the timer resolution
 * is too coarse. */
static double benchQuickness(double ref_us, double time_us)
{
    return (time_us != 0.0 ? ref_us / time_us : 0.0);
}

/*
 * FFT benchmark entry point.
 *
 * Builds a 100 Hz sine sampled at 8 kHz, cut into 30 ms frames every 10 ms
 * (batch * 128 frames). For each FFT order from max(spNextPow2(framel),
 * min_order) to max_order, it:
 *  - times spfft(), the plugin complex FFT, the plugin real FFT and the
 *    plugin batched real FFT, in float and double precision;
 *  - reports the timings on stderr;
 *  - measures each plugin result's distance from the reference;
 *  - prints one whitespace-separated line per order to stdout:
 *      order, reference time, {time, quickness} for each of the 6 plugin
 *      variants, then the 6 distances.
 *
 * Returns 0 on success, or 1 if the batch size is invalid or the plugin
 * cannot be loaded.
 */
int spMain(int argc, char **argv)
{
    long k;
    double fs;
    double freq;
    double lengthm, shiftm, framem;
    long max_nframe;
    long length, shiftl, framel;
    long fftl;
    long order;
    spBool batch_calc_power;
    double difftime, quick_difftime, double_quick_difftime;
    double time_us, quick_time_us, double_quick_time_us;
    double quick_rate, double_quick_rate;
    double quick_dist, double_quick_dist;
    double quick_real_difftime, double_quick_real_difftime;
    double quick_real_time_us, double_quick_real_time_us;
    double quick_real_rate, double_quick_real_rate;
    double quick_real_dist, double_quick_real_dist;
    double quick_real_batched_difftime, double_quick_real_batched_difftime;
    double quick_real_batched_time_us, double_quick_real_batched_time_us;
    double quick_real_batched_rate, double_quick_real_batched_rate;
    double quick_real_batched_dist, double_quick_real_batched_dist;
    long nframe;
    spDVector vec;
    const char *filename;
    spPlugin *plugin = NULL;

    spSetHelpMessage(&help_flag, "FFT benchmark program");
    options = spGetOptions(argc, argv, option, file_label);
    spGetOptionsValue(argc, argv, options);
    spSetDebugLevel(debug_level);

    /* batch sets both the batched FFT size and the signal length below.
     * A non-positive value gives max_nframe <= 0 (no frame limit), a zero
     * or negative signal length, and an unusable batched FFT. */
    if (batch <= 0) {
        fprintf(stderr, "Invalid batch size: %ld\n", batch);
        return 1;
    }

    filename = spGetFile(options);
    
    if (!spStrNone(filename)) {
        spStrCopy(plugin_name, sizeof(plugin_name), filename);
    }

    spDebug(-10, "main", "plugin_name = %s\n", plugin_name);
    
    if (!spStrNone(plugin_name)
        && (plugin = spLoadFFTPlugin(plugin_name)) == NULL) {
        fprintf(stderr, "Can't open FFT plugin: %s\n", plugin_name);
        return 1;
    }

#if 0
    batch_calc_power = SP_TRUE;
#else
    batch_calc_power = SP_FALSE;
#endif
    /* max_nframe = 1000; */
    max_nframe = batch * 128;
    fs = 8000.0;
    freq = 100.0;
    shiftm = 10.0;
    framem = 30.0;
    /* one extra shift so the signal covers max_nframe frames */
    lengthm = (double)(max_nframe + 1) * shiftm;

    length = (long)spRound(fs * (lengthm / 1000.0));
    shiftl = (long)spRound(fs * (shiftm / 1000.0));
    framel = (long)spRound(fs * (framem / 1000.0));
    spDebug(1, "main", "length = %ld, shiftl = %ld, framel = %ld\n",
            length, shiftl, framel);
    
    vec = xdvalloc(length);

    for (k = 0; k < vec->length; k++) {
        vec->data[k] = sin(2.0 * PI * freq * (double)k / fs);
    }

    /* never use an FFT shorter than one frame */
    order = spMax(spNextPow2(framel), min_order);
    
    for (; order <= max_order; order++) {
        fftl = POW2(order);

        /* reference: spfft() */
        nframe = analysisLoop(vec, shiftl, framel, fftl, max_nframe, &difftime);
        time_us = benchFrameTimeUs(difftime, nframe);
        fprintf(stderr, "FFT calculation time (normal FFT of length %ld) = %.2f usec\n", fftl, time_us);

        /* plugin complex FFT */
        nframe = analysisLoop2(plugin, SP_FFT_FLOAT_PRECISION, vec, shiftl, framel, fftl, max_nframe, NULL, &quick_difftime);
        quick_time_us = benchFrameTimeUs(quick_difftime, nframe);
        fprintf(stderr, "FFT calculation time (float quick FFT of length %ld) = %.2f usec\n", fftl, quick_time_us);
        quick_rate = benchQuickness(time_us, quick_time_us);
        fprintf(stderr, "float quick FFT quickness for length %ld = %.2f\n", fftl, quick_rate);
        
        nframe = analysisLoop2(plugin, SP_FFT_DOUBLE_PRECISION, vec, shiftl, framel, fftl, max_nframe, NULL, &double_quick_difftime);
        double_quick_time_us = benchFrameTimeUs(double_quick_difftime, nframe);
        fprintf(stderr, "FFT calculation time (double quick FFT of length %ld) = %.2f usec\n", fftl, double_quick_time_us);
        double_quick_rate = benchQuickness(time_us, double_quick_time_us);
        fprintf(stderr, "double quick FFT quickness for length %ld = %.2f\n", fftl, double_quick_rate);

        /* plugin real FFT */
        nframe = analysisLoop3(plugin, SP_FFT_FLOAT_PRECISION, vec, shiftl, framel, fftl, max_nframe, NULL, &quick_real_difftime);
        quick_real_time_us = benchFrameTimeUs(quick_real_difftime, nframe);
        fprintf(stderr, "FFT calculation time (float quick real FFT of length %ld) = %.2f usec\n", fftl, quick_real_time_us);
        quick_real_rate = benchQuickness(time_us, quick_real_time_us);
        fprintf(stderr, "float quick real FFT quickness for length %ld = %.2f\n", fftl, quick_real_rate);
        
        nframe = analysisLoop3(plugin, SP_FFT_DOUBLE_PRECISION, vec, shiftl, framel, fftl, max_nframe, NULL, &double_quick_real_difftime);
        double_quick_real_time_us = benchFrameTimeUs(double_quick_real_difftime, nframe);
        fprintf(stderr, "FFT calculation time (double quick real FFT of length %ld) = %.2f usec\n", fftl, double_quick_real_time_us);
        double_quick_real_rate = benchQuickness(time_us, double_quick_real_time_us);
        fprintf(stderr, "double quick real FFT quickness for length %ld = %.2f\n", fftl, double_quick_real_rate);

        /* plugin batched real FFT */
        nframe = analysisLoopBatched(plugin, SP_FFT_FLOAT_PRECISION, vec, shiftl, framel, fftl, batch,
                                     max_nframe, NULL, &quick_real_batched_difftime, batch_calc_power);
        quick_real_batched_time_us = benchFrameTimeUs(quick_real_batched_difftime, nframe);
        fprintf(stderr, "FFT calculation time (float quick real batched FFT of length %ld) = %.2f usec\n",
                fftl, quick_real_batched_time_us);
        quick_real_batched_rate = benchQuickness(time_us, quick_real_batched_time_us);
        fprintf(stderr, "float quick real batched FFT quickness for length %ld = %.2f\n", fftl, quick_real_batched_rate);

        nframe = analysisLoopBatched(plugin, SP_FFT_DOUBLE_PRECISION, vec, shiftl, framel, fftl, batch,
                                     max_nframe, NULL, &double_quick_real_batched_difftime, batch_calc_power);
        double_quick_real_batched_time_us = benchFrameTimeUs(double_quick_real_batched_difftime, nframe);
        fprintf(stderr, "FFT calculation time (double quick real batched FFT of length %ld) = %.2f usec\n",
                fftl, double_quick_real_batched_time_us);
        double_quick_real_batched_rate = benchQuickness(time_us, double_quick_real_batched_time_us);
        fprintf(stderr, "double quick real batched FFT quickness for length %ld = %.2f\n", fftl, double_quick_real_batched_rate);

        /* accuracy runs are separate so that reference computation does not
         * pollute the timings above */
        analysisLoop2(plugin, SP_FFT_FLOAT_PRECISION, vec, shiftl, framel, fftl, max_nframe, &quick_dist, NULL);
        analysisLoop2(plugin, SP_FFT_DOUBLE_PRECISION, vec, shiftl, framel, fftl, max_nframe, &double_quick_dist, NULL);
        analysisLoop3(plugin, SP_FFT_FLOAT_PRECISION, vec, shiftl, framel, fftl, max_nframe, &quick_real_dist, NULL);
        analysisLoop3(plugin, SP_FFT_DOUBLE_PRECISION, vec, shiftl, framel, fftl, max_nframe, &double_quick_real_dist, NULL);
        analysisLoopBatched(plugin, SP_FFT_FLOAT_PRECISION, vec, shiftl, framel, fftl, batch,
                            max_nframe, &quick_real_batched_dist, NULL, batch_calc_power);
        analysisLoopBatched(plugin, SP_FFT_DOUBLE_PRECISION, vec, shiftl, framel, fftl, batch,
                            max_nframe, &double_quick_real_batched_dist, NULL, batch_calc_power);
        
        printf("%ld %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f\n", order, time_us,
               quick_time_us, quick_rate, double_quick_time_us, double_quick_rate,
               quick_real_time_us, quick_real_rate, double_quick_real_time_us, double_quick_real_rate,
               quick_real_batched_time_us, quick_real_batched_rate, double_quick_real_batched_time_us, double_quick_real_batched_rate,
               quick_dist, double_quick_dist, quick_real_dist, double_quick_real_dist,
               quick_real_batched_dist, double_quick_real_batched_dist);
    }
    
    xdvfree(vec);

    if (plugin != NULL) {
        spFreeFFTPlugin(plugin);
    }
    
    return 0;
}
