library(matrixStats) # for numerically stable logsumexps

options(amelia.parallel="no",
        amelia.ncpus=1)
library(Amelia)
source("measerr_methods.R")
source("pl_methods.R")

## Collect Bxy/Bzy point estimates and 95% confidence intervals from an lm fit.
##
## model  - fitted `lm` object containing the coefficient `x_term` and `z`.
## x_term - name of the coefficient playing the role of x (e.g. 'x', 'w_pred').
## suffix - method label appended to every result key (e.g. 'true', 'naive').
##
## Returns a named list with, in order, Bxy.est, Bzy.est, Bxy.ci.upper,
## Bxy.ci.lower, Bzy.ci.upper, Bzy.ci.lower (each suffixed with `.<suffix>`).
## Values keep the names R attaches: the coefficient name for estimates and
## "2.5 %" / "97.5 %" for interval bounds.
.summarize_lm_fit <- function(model, x_term, suffix) {
    est <- coef(model)
    ci <- confint(model)  # computed once; rows are terms, columns are bounds
    ci.x <- ci[x_term, ]
    ci.z <- ci['z', ]
    out <- list(est[x_term], est['z'],
                ci.x[2], ci.x[1],
                ci.z[2], ci.z[1])
    names(out) <- paste0(c('Bxy.est.', 'Bzy.est.',
                           'Bxy.ci.upper.', 'Bxy.ci.lower.',
                           'Bzy.ci.upper.', 'Bzy.ci.lower.'),
                         suffix)
    out
}

## Collect Bxy/Bzy point estimates and Wald 95% intervals from a fit returned
## by measerr_mle() / measerr_irr_mle().
##
## fit    - list with a named parameter vector `par` (containing 'x' and 'z')
##          and a `hessian`. The inverse hessian is used as the covariance
##          matrix, which presumes it is the hessian of the NEGATIVE
##          log-likelihood -- TODO confirm against measerr_methods.R.
## suffix - method label appended to every result key.
##
## Returns the same six-entry layout as .summarize_lm_fit().
.summarize_mle_fit <- function(fit, suffix) {
    est <- fit$par
    vcov.est <- solve(fit$hessian)
    se <- sqrt(diag(vcov.est))
    ci.upper <- est + se * 1.96
    ci.lower <- est - se * 1.96
    out <- list(est['x'], est['z'],
                ci.upper['x'], ci.lower['x'],
                ci.upper['z'], ci.lower['z'])
    names(out) <- paste0(c('Bxy.est.', 'Bzy.est.',
                           'Bxy.ci.upper.', 'Bxy.ci.lower.',
                           'Bzy.ci.upper.', 'Bzy.ci.lower.'),
                         suffix)
    out
}

## Fit every estimator under study to one simulated dataset and append each
## estimator's Bxy/Bzy point estimates and 95% intervals to `result`.
##
## df              - data.table with columns y, x, z, w_pred, x.obs.0 and
##                   x.obs.1 (x.obs.1 is NA on rows it does not cover).
## result          - list to extend; it is returned with new entries appended.
## outcome_formula, proxy_formula, truth_formula
##                 - forwarded to measerr_mle() / measerr_irr_mle().
## coder_formulas  - forwarded to measerr_irr_mle(); the double-proxy fit uses
##                   only the first element, the triple-proxy fit uses all.
##
## Returns `result` with entries named <quantity>.<method>. Failures of the
## Amelia and Zhang estimators are caught and their messages accumulated in
## result$error; failures of any other estimator propagate to the caller.
run_simulation <- function(df, result, outcome_formula = y ~ x + z, proxy_formula = w_pred ~ x, coder_formulas=list(x.obs.1~x, x.obs.0~x), truth_formula = x ~ z){

    ## Share of rows where the proxy w_pred equals the true x.
    accuracy <- df[,mean(w_pred==x)]
    result <- append(result, list(accuracy=accuracy))

    ## Oracle OLS on the true x (fixed formula; does not use outcome_formula).
    model.true <- lm(y ~ x + z, data=df)
    result <- append(result, .summarize_lm_fit(model.true, 'x', 'true'))

    ## Feasible OLS using x.obs.0 on the rows where x.obs.1 is observed.
    loa0.feasible <- lm(y ~ x.obs.0 + z, data = df[!(is.na(x.obs.1))])
    result <- append(result, .summarize_lm_fit(loa0.feasible, 'x.obs.0', 'loa0.feasible'))
    print("fitting loa0 model")

    ## MLE with x replaced by x.obs.0.
    df.loa0.mle <- copy(df)
    df.loa0.mle[,x:=x.obs.0]
    loa0.mle <- measerr_mle(df.loa0.mle, outcome_formula=outcome_formula, proxy_formula=proxy_formula, truth_formula=truth_formula)
    result <- append(result, .summarize_mle_fit(loa0.mle, 'loa0.mle'))

    ## Feasible OLS on the rows where x.obs.1 is observed and agrees with x.obs.0.
    loco.feasible <- lm(y ~ x.obs.1 + z, data = df[(!is.na(x.obs.1)) & (x.obs.1 == x.obs.0)])
    result <- append(result, .summarize_lm_fit(loco.feasible, 'x.obs.1', 'loco.feasible'))

    ## Naive OLS plugging the proxy w_pred in for x.
    model.naive <- lm(y~w_pred+z, data=df)
    result <- append(result, .summarize_lm_fit(model.naive, 'w_pred', 'naive'))

    print("fitting loco model")

    ## MLE where x is only observed on rows where x.obs.0 == x.obs.1 (NA
    ## elsewhere); x.true keeps the real x for the diagnostics below.
    df.loco.mle <- copy(df)
    df.loco.mle[,x.obs:=NA]
    df.loco.mle[(x.obs.0)==(x.obs.1),x.obs:=x.obs.0]
    df.loco.mle[,x.true:=x]
    df.loco.mle[,x:=x.obs]
    print(df.loco.mle[!is.na(x.obs.1),mean(x.true==x,na.rm=TRUE)])
    loco.accuracy <- df.loco.mle[(!is.na(x.obs.1)) & (x.obs.1 == x.obs.0),mean(x.obs.1 == x.true)]
    loco.mle <- measerr_mle(df.loco.mle, outcome_formula=outcome_formula, proxy_formula=proxy_formula, truth_formula=truth_formula)
    result <- append(result, c(list(loco.accuracy=loco.accuracy),
                               .summarize_mle_fit(loco.mle, 'loco.mle')))

    ## MLE using w_pred plus the first coder formula only.
    df.double.proxy.mle <- copy(df)
    df.double.proxy.mle[,x.obs:=NA]
    print("fitting double proxy model")
    double.proxy.mle <- measerr_irr_mle(df.double.proxy.mle, outcome_formula=outcome_formula, proxy_formula=proxy_formula, coder_formulas=coder_formulas[1], truth_formula=truth_formula)
    result <- append(result, .summarize_mle_fit(double.proxy.mle, 'double.proxy'))

    ## MLE using w_pred plus all coder formulas.
    df.triple.proxy.mle <- copy(df)
    df.triple.proxy.mle[,x.obs:=NA]
    print("fitting triple proxy model")
    triple.proxy.mle <- measerr_irr_mle(df.triple.proxy.mle, outcome_formula=outcome_formula, proxy_formula=proxy_formula, coder_formulas=coder_formulas, truth_formula=truth_formula)
    result <- append(result, .summarize_mle_fit(triple.proxy.mle, 'triple.proxy'))

    ## Multiple imputation of x.obs with Amelia, pooled OLS via Zelig.
    ## The handler must RETURN the updated list: assigning result$error inside
    ## the handler alone only modifies the handler's local copy.
    ## NOTE(review): idvars lists 'w' while the proxy column used elsewhere is
    ## `w_pred` -- confirm df really has a `w` column, otherwise amelia() may
    ## fail on every run and this estimator is always skipped.
    result <- tryCatch({
        amelia.out.k <- amelia(df.loco.mle, m=200, p2s=0, idvars=c('x.true','w','x.obs.1','x.obs.0','x'))
        mod.amelia.k <- zelig(y~x.obs+z, model='ls', data=amelia.out.k$imputations, cite=FALSE)
        coefse <- combine_coef_se(mod.amelia.k, messages=FALSE)

        est.x.mi <- coefse['x.obs','Estimate']
        est.x.se <- coefse['x.obs','Std.Error']
        result <- append(result,
                         list(Bxy.est.amelia.full = est.x.mi,
                              Bxy.ci.upper.amelia.full = est.x.mi + 1.96 * est.x.se,
                              Bxy.ci.lower.amelia.full = est.x.mi - 1.96 * est.x.se
                              ))

        est.z.mi <- coefse['z','Estimate']
        est.z.se <- coefse['z','Std.Error']
        result <- append(result,
                         list(Bzy.est.amelia.full = est.z.mi,
                              Bzy.ci.upper.amelia.full = est.z.mi + 1.96 * est.z.se,
                              Bzy.ci.lower.amelia.full = est.z.mi - 1.96 * est.z.se
                              ))
        result
    },
    error = function(e){
        message("An error occurred:\n",e)
        result$error <- paste0(result$error,'\n', e)
        result
    })

    ## Zhang likelihood estimator; errors are recorded the same way.
    result <- tryCatch({
        mod.zhang.lik <- zhang.mle.iv(df.loco.mle)
        zhang.est <- coef(mod.zhang.lik)
        zhang.ci <- confint(mod.zhang.lik,method='quad')
        append(result,
               list(Bxy.est.zhang = zhang.est['Bxy'],
                    Bxy.ci.upper.zhang = zhang.ci['Bxy','97.5 %'],
                    Bxy.ci.lower.zhang = zhang.ci['Bxy','2.5 %'],
                    Bzy.est.zhang = zhang.est['Bzy'],
                    Bzy.ci.upper.zhang = zhang.ci['Bzy','97.5 %'],
                    Bzy.ci.lower.zhang = zhang.ci['Bzy','2.5 %']))
    },
    error = function(e){
        message("An error occurred:\n",e)
        result$error <- paste0(result$error,'\n', e)
        result
    })

    ## Predicted-covariates GMM estimator. Rows are sorted by x.obs. The
    ## indicators are derived from is.na() so they stay correct when every row
    ## (or no row) has an observed x.obs: v flags rows with x.obs observed and
    ## p flags rows where only the prediction w is available.
    df.gmm <- df.loco.mle[order(x.obs)]
    y <- df.gmm[,y]
    x <- df.gmm[,x.obs]
    z <- df.gmm[,z]
    w <- df.gmm[,w_pred]
    v <- as.numeric(!is.na(x))
    p <- as.numeric(is.na(x))
    train <- rep(0, nrow(df.gmm))
    gmm.res <- predicted_covariates(y, x, z, w, v, train, p, max_iter=100, verbose=TRUE)

    result <- append(result,
                     list(Bxy.est.gmm = gmm.res$beta[1,1],
                          Bxy.ci.upper.gmm = gmm.res$confint[1,2],
                          Bxy.ci.lower.gmm = gmm.res$confint[1,1],
                          gmm.ER_pval = gmm.res$ER_pval
                          ))

    result <- append(result,
                     list(Bzy.est.gmm = gmm.res$beta[2,1],
                          Bzy.ci.upper.gmm = gmm.res$confint[2,2],
                          Bzy.ci.lower.gmm = gmm.res$confint[2,1]))

    return(result)
}