]> code.communitydata.science - social-media-chapter.git/blob - code/prediction/03-prediction_analysis.R
initial import of material for public archive into git
[social-media-chapter.git] / code / prediction / 03-prediction_analysis.R
1 library(data.table)
2 library(Matrix)
3 library(glmnet)
4 library(xtable)
5 library(methods)
6
# Container for the results this script saves to paper/data/prediction.RData.
predict.list <- NULL

# Load each input from disk only when it is not already in the workspace.
if (!exists("top.ngram.matrix")) {
    load("processed_data/top.ngram.matrix.RData")
}

if (!exists("pred.descrip")) {
    load("paper/data/prediction_descriptives.RData")
    covars <- pred.descrip$covars
}

# Convert both inputs to data.tables keyed (and therefore sorted) by the
# article id `eid`.
covars <- data.table(pred.descrip$covars)
setkey(covars, eid)

top.ngram.matrix <- data.table(top.ngram.matrix)
setkey(top.ngram.matrix, eid)
22
# Restrict both tables to the articles (eids) that appear in each of them.
overlap.eids <- intersect(covars$eid, top.ngram.matrix$eid)
covars <- covars[covars$eid %in% overlap.eids, ]
top.ngram.matrix <- top.ngram.matrix[top.ngram.matrix$eid %in% overlap.eids, ]

# The model formulas below use `cited.x` as the response; if the covariates
# carry that column as plain `cited`, give it the `cited.x` name.
names(covars) <- replace(names(covars), names(covars) == "cited", "cited.x")

# Inner-join covariates and n-gram counts on `eid`; `d` is the data source
# for the formula-based model matrices below.
# NOTE: merge() suffixes non-key column names that occur in both tables, so
# a few columns come out of this join renamed.
d <- merge(covars, top.ngram.matrix, by = "eid", all = FALSE)
37
# construct model matrices. Each factor is expanded into dummy columns and
# [, -1] drops the intercept column added by sparse.model.matrix();
# drop = FALSE keeps a one-dummy factor as a (named) one-column matrix.
x.controls <- sparse.model.matrix(cited.x ~ language.x +
                                      modal_country + month.x,
                                  data = d)[, -1, drop = FALSE]

x.aff <- sparse.model.matrix(cited.x ~ affiliation, data = d)[, -1, drop = FALSE]
x.subj <- sparse.model.matrix(cited.x ~ subject.x, data = d)[, -1, drop = FALSE]
x.venue <- sparse.model.matrix(cited.x ~ source_title,
                               data = d)[, -1, drop = FALSE]

x.ngrams <- as.matrix(subset(top.ngram.matrix, select = -eid))
x.ngrams <- as(x.ngrams, "sparseMatrix")

# The feature blocks come from three objects (d, covars and
# top.ngram.matrix) and are glued together column-wise, so every block must
# hold one row per article in the same order. Fail loudly rather than
# silently pairing features with the wrong outcomes: model-matrix
# construction can drop rows with missing values, and merge() re-sorts by
# key and would multiply rows if an eid were duplicated.
stopifnot(
    all(vapply(list(x.controls, x.aff, x.subj, x.venue, x.ngrams), NROW,
               integer(1)) == nrow(covars)),
    identical(as.character(d$eid), as.character(covars$eid)),
    identical(as.character(top.ngram.matrix$eid), as.character(covars$eid))
)

# Nested feature sets: controls (+ year and works cited), then + affiliation,
# + subject, + venue, + n-gram terms. Base cbind() dispatches to the Matrix
# methods for sparse arguments; Matrix::cBind() is deprecated (defunct in
# current Matrix releases).
X <- cbind(x.controls, covars$year.x, covars$works.cited)
X.aff <- cbind(X, x.aff)
X.subj <- cbind(X.aff, x.subj)
X.venue <- cbind(X.subj, x.venue)
X.terms <- cbind(X.venue, x.ngrams)

# Outcome. Use the exact column name: `cited` was renamed to `cited.x` above,
# so `covars$cited` only resolved through `$` partial matching (and yields
# NULL as soon as another column name also starts with "cited").
Y <- covars[["cited.x"]]
stopifnot(length(Y) == nrow(X))
57
### Hold-back sample for testing model performance later on:
# NOTE(naming): the *.hold objects are the random 10 percent evaluation
# sample; the *.test objects are the remaining rows, which the models below
# are fitted on (i.e. they are the training data, not a test set).
set.seed(20160719)
holdback.index <- sample(nrow(X), round(nrow(X) * .1))

# Select the training rows with a logical mask instead of
# X[-holdback.index, ]: negative indexing with an empty index vector selects
# zero rows, which would silently empty the training set.
in.holdback <- seq_len(nrow(X)) %in% holdback.index

# drop = FALSE keeps single-row selections as matrices.
X.hold <- X[holdback.index, , drop = FALSE]
X.hold.aff <- X.aff[holdback.index, , drop = FALSE]
X.hold.subj <- X.subj[holdback.index, , drop = FALSE]
X.hold.venue <- X.venue[holdback.index, , drop = FALSE]
X.hold.terms <- X.terms[holdback.index, , drop = FALSE]
Y.hold <- Y[holdback.index]

X.test <- X[!in.holdback, , drop = FALSE]
X.test.aff <- X.aff[!in.holdback, , drop = FALSE]
X.test.subj <- X.subj[!in.holdback, , drop = FALSE]
X.test.venue <- X.venue[!in.holdback, , drop = FALSE]
X.test.terms <- X.terms[!in.holdback, , drop = FALSE]
Y.test <- Y[!in.holdback]
75
###############  Models and prediction
# The five nested feature sets are all fitted with the same procedure. The
# seed is set once: cv.glmnet() assigns observations to CV folds at random,
# so the fits must keep this order for the results to reproduce.
set.seed(20160719)

# Cross-validated lasso (alpha = 1) logistic regression, tuned on the
# misclassification rate.
fit.cv.lasso <- function(x, y) {
    cv.glmnet(x, y, alpha = 1, family = "binomial", type.measure = "class")
}

# Predicted class labels for `newx` at the CV-optimal lambda (lambda.min).
hold.class.pred <- function(model, newx) {
    predict(model, type = "class", s = "lambda.min", newx = newx)
}

m.con <- fit.cv.lasso(X.test, Y.test)
con.pred <- hold.class.pred(m.con, X.hold)

m.aff <- fit.cv.lasso(X.test.aff, Y.test)
aff.pred <- hold.class.pred(m.aff, X.hold.aff)

m.subj <- fit.cv.lasso(X.test.subj, Y.test)
subj.pred <- hold.class.pred(m.subj, X.hold.subj)

m.venue <- fit.cv.lasso(X.test.venue, Y.test)
venue.pred <- hold.class.pred(m.venue, X.hold.venue)

m.terms <- fit.cv.lasso(X.test.terms, Y.test)
terms.pred <- hold.class.pred(m.terms, X.hold.terms)
104
##########
# Gather the hold-back class predictions of every model (one column each,
# in nesting order) for comparison with the observed hold-back outcomes.
pred.df <- data.frame(do.call(cbind, list(con.pred, aff.pred, subj.pred,
                                          venue.pred, terms.pred)))
names(pred.df) <- c("Controls", "+ Affiliation", "+ Subject", "+ Venue",
                    "+ Terms")

# The fitted models, in the same order as the columns of pred.df.
m.list <- list(m.con, m.aff, m.subj, m.venue, m.terms)
114
# collect, per model:
#   - df (number of non-zero coefficients)
#   - percent deviance explained
#   - cross-validated prediction error
# (hold-back prediction error is computed separately by gen.class.err()).

# Summarise a fitted cv.glmnet model at a single value of lambda.
#
# model: a cv.glmnet fit; reads $lambda, $cvm, $glmnet.fit$df,
#        $glmnet.fit$dev.ratio and, for character `s`, model[[s]].
# s:     "lambda.min" (default), "lambda.1se", or one numeric lambda. The
#        default matches the lambda used for the hold-back predictions.
# Returns c(df, percent deviance explained, CV error in percent).
gen.m.summ.info <- function(model, s = "lambda.min") {
    lambda <- if (is.character(s)) model[[s]] else s
    if (length(lambda) != 1L || !is.numeric(lambda)) {
        stop("`s` must name a lambda stored in `model` or be a single ",
             "numeric lambda", call. = FALSE)
    }
    # Index the lambda path at the chosen lambda. Taking the last element
    # (tail) would report the least-regularised fit on the path, not the one
    # whose predictions are evaluated.
    i <- which.min(abs(model$lambda - lambda))
    df <- round(model$glmnet.fit$df[i], 0)
    percent.dev <- round(model$glmnet.fit$dev.ratio[i], 2) * 100
    cv.error <- round(model$cvm[i], 2) * 100
    c(df, percent.dev, cv.error)
}
129
# Misclassification rate in percent (the proportion is rounded to two
# decimals before scaling, like the other summary statistics).
#
# pred: predicted class labels, e.g. predict(..., type = "class") output.
# test: observed outcomes for the same rows.
gen.class.err <- function(pred, test) {
    # Compare labels as text so character/factor predictions can be matched
    # against numeric, logical or factor outcomes. Unlike reading the
    # off-diagonal cells of table(pred, test), this also works when a model
    # predicts one class for every row (a 1 x 2 table). Pairs with a missing
    # value are ignored, as table() ignores them.
    mismatch <- as.character(pred) != as.character(test)
    round(mean(mismatch, na.rm = TRUE), 2) * 100
}
135
136
# Model-comparison table: one row per model, in the order of m.list and of
# the columns of pred.df. All cells are stored as text.
results.tab <- local({
    stopifnot(length(m.list) == ncol(pred.df))
    summ <- do.call(rbind, lapply(m.list, gen.m.summ.info))
    hold.err <- vapply(pred.df, gen.class.err, numeric(1), test = Y.hold)
    tab <- data.frame(names(pred.df), summ, unname(hold.err),
                      stringsAsFactors = FALSE)
    data.frame(lapply(tab, as.character), stringsAsFactors = FALSE)
})

names(results.tab) <- c("Model", "N features", "Deviance (%)",
                        "CV error (%)", "Hold-back error (%)")

print(xtable(results.tab,
             caption=
                 "Summary of fitted models predicting any citations. The ``Model'' column describes which features were included. The N features column shows the number of features included in the prediction. ``Deviance'' summarizes the goodness of fit as a percentage of the total deviance accounted for by the model. ``CV error'' (cross-validation error) reports the prediction error rates of each model in the cross-validation procedure conducted as part of the parameter estimation process. ``Hold-back error'' shows the prediction error on a random 10 percent subset of the original dataset not included in any of the model estimation procedures.",
             label = 'tab:predict_models', align = 'llrrrr'),
      include.rownames = FALSE)

# Store the results:
predict.list$results.tab <- results.tab
160
161
162
163
############# Generate most salient coefficients

# Non-zero coefficients of the full (+ Terms) model at lambda.min, sorted by
# decreasing absolute size. Names come from the rownames of the coefficient
# matrix itself. That matrix starts with the intercept, so its row positions
# are offset by one from the columns of X.test.terms; indexing colnames()
# with them would shift every feature name onto its neighbour's coefficient.
nz.coefs <- local({
    terms.coef <- as.matrix(coef(m.terms, s = "lambda.min"))
    keep <- which(terms.coef[, 1] != 0)
    out <- data.frame(coef = rownames(terms.coef)[keep],
                      type = "term",
                      beta = unname(terms.coef[keep, 1]),
                      stringsAsFactors = FALSE)
    out[order(-abs(out$beta)), ]
})
179
# comparison:
# Label each non-zero coefficient with the kind of feature it belongs to.
# Dummy columns are named <formula variable><level>, so the patterns are
# anchored at the start of the name. Unanchored substring matches let, e.g.,
# an n-gram or venue title containing "language" be labelled as the language
# control, and the cleanup deleted such words from inside names.
# (An n-gram that itself begins with one of these prefixes is still
# mislabelled; only classifying by column position would avoid that.)
nz.coefs <- local({
    type.patterns <- c(venue       = "^source_title",
                       subject     = "^subject\\.x",
                       affiliation = "^affiliation",
                       month       = "^month\\.x",
                       country     = "^modal_country",
                       language    = "^language\\.x",
                       year        = "^20[0-9]{2}$")
    for (type.name in names(type.patterns)) {
        hit <- grepl(type.patterns[[type.name]], nz.coefs$coef)
        nz.coefs$type[hit] <- type.name
    }
    nz.coefs$type[nz.coefs$coef == "(Intercept)"] <- NA

    # cleanup: strip the variable-name prefix from venue, subject and
    # affiliation dummies, leaving just the level label. Only rows of the
    # matching type are touched.
    for (type.name in c("venue", "subject", "affiliation")) {
        rows <- nz.coefs$type %in% type.name
        nz.coefs$coef[rows] <- sub(type.patterns[[type.name]], "",
                                   nz.coefs$coef[rows])
    }
    nz.coefs$beta <- round(nz.coefs$beta, 3)
    names(nz.coefs) <- c("Feature", "Type", "Coefficient")
    nz.coefs
})

predict.list$nz.coefs <- nz.coefs
201
# Share (percent) of each feature type among all non-zero coefficients.
# These bare expressions print only when top-level values are auto-printed
# (e.g. under Rscript or source(echo = TRUE)).
round(prop.table(table(nz.coefs$Type)) * 100, 2)

# ... and among the 700, 200 and 100 largest coefficients (nz.coefs is
# sorted by decreasing absolute coefficient size).
round(prop.table(table(head(nz.coefs$Type, 700))) * 100, 2)
round(prop.table(table(head(nz.coefs$Type, 200))) * 100, 2)
round(prop.table(table(head(nz.coefs$Type, 100))) * 100, 2)

# Number of coefficients shown in the LaTeX table. The caption is built from
# the same value, so it cannot claim a different count than is printed.
n.top.coefs <- 10L
print(xtable(
    as.matrix(head(nz.coefs, n.top.coefs)),
    label = 'tab:nzcoefs',
    caption = sprintf(paste('Feature, variable type, and beta value for top %d',
                            'non-zero coefficients estimated by the best',
                            'fitting model with all features included.'),
                      n.top.coefs),
    align = 'lllr'
), include.rownames = FALSE)


# output
save(predict.list, file = "paper/data/prediction.RData")
220
221

Community Data Science Collective || Want to submit a patch?