-- Row type returned by predict_classification_model: the input row's
-- clean_pt_id together with the class predicted for it.
CREATE TYPE predict_classification_model_type AS (
    clean_pt_id integer,
    class       varchar(20)
);

-- NOTE: Only a superuser can create PL/R functions (PL/R is an untrusted language).
-- predict_classification_model(filename, modelname, schemaname, tablename, wherecondition)
--
-- Score the rows of a table with a classification model stored in an .RData file
-- and return one (clean_pt_id, class) row per selected input row.
--
--   filename       path, on the database server, of an .RData file read with R's load()
--   modelname      name of the object inside that file that is passed to predict()
--   schemaname     schema of the input table; empty or NULL means no schema qualifier
--   tablename      input table; must have a clean_pt_id column plus the model's predictors
--   wherecondition SQL boolean expression used to filter the input rows; empty or NULL
--                  means all rows. It is spliced into the query text verbatim, so it must
--                  only ever come from a trusted caller.
--
-- If the .RData file also defines na.cols and na.replace.val, NAs in the na.cols
-- columns are replaced by na.replace.val before prediction; otherwise no replacement
-- is done.
CREATE OR REPLACE FUNCTION predict_classification_model
       (filename varchar(256), modelname varchar(256),
        schemaname varchar(256), tablename varchar(256),
        wherecondition varchar(256) )
RETURNS SETOF predict_classification_model_type AS
$BODY$

pg.thrownotice('Starting predict_classification_model')

mp <- capture.output(memory.profile())
for (m in mp) pg.thrownotice(m)

# A SQL NULL argument does not reach R as a string (NA/NULL instead); treat it
# like an empty string so the nchar() tests never see NA.
is.blank <- function(x) {
  is.null(x) || length(x) == 0 || is.na(x) || nchar(x) == 0
}

tablename <- pg.quoteident(tablename)
if (!is.blank(schemaname)) {
  tablename <- paste(pg.quoteident(schemaname), tablename, sep=".")
}

# wherecondition is a SQL expression, not a value, so it must NOT go through
# pg.quoteliteral(): that turns it into a string literal such as 'x > 0',
# which PostgreSQL cannot cast to boolean (only literals like 'TRUE' work).
# It is parenthesized defensively so it stays a single boolean term.
# SECURITY: this is raw SQL; only trusted callers may supply it.
if (is.blank(wherecondition)) {
  wherecondition <- 'TRUE'
}
wherecondition <- paste('(', wherecondition, ')', sep="")

# Load the model file into this function's environment. This can be fairly large.
# load() deserializes arbitrary R objects, so only point it at trusted files.
load(filename)
themodel <- get(modelname)

# Row count, used only for progress notices. The cursor loop below does not
# depend on it, so a different snapshot for the cursor cannot misalign results.
n <- pg.spi.exec(paste('SELECT count(*) FROM', tablename, 'WHERE', wherecondition))$count

# Optional NA clean-up settings, expected to come from the loaded .RData file.
fix.na <- exists("na.cols") && exists("na.replace.val")

# Open a cursor for loading in data a chunk at a time.
p <- pg.spi.prepare(paste('SELECT * FROM', tablename, 'WHERE', wherecondition))
cur <- pg.spi.cursor_open('my_cursor', p)

M <- 10000  # chunk size; this being large helps avoid poorly factorizing columns
chunks <- list()
fetched <- 0
repeat {
  rows <- pg.spi.cursor_fetch(cur, TRUE, as.integer(M))
  # Stop once the cursor is exhausted (no data.frame / zero rows returned).
  if (!is.data.frame(rows) || nrow(rows) == 0) break
  d <- pg.spi.factor(rows)

  # Fix NA in certain columns
  if (fix.na) {
    l <- d[, na.cols]
    l[is.na(l)] <- na.replace.val
    d[, na.cols] <- l
  }

  ids <- d$clean_pt_id
  if (is.null(ids)) stop("input table has no clean_pt_id column")
  # A factor id (from a non-numeric column) would otherwise be stored as level codes.
  if (is.factor(ids)) ids <- as.character(ids)

  # Predict classes
  pr <- as.character(predict(object = themodel, newdata = d, type = "response"))

  pg.thrownotice(paste(fetched + 1, "of", n))

  # Keep each column's own type instead of cbind()-ing everything to character.
  chunks[[length(chunks) + 1]] <- data.frame(clean_pt_id = ids, class = pr,
                                             stringsAsFactors = FALSE)
  fetched <- fetched + nrow(d)
}

pg.spi.cursor_close(cur)

mp <- capture.output(memory.profile())
for (m in mp) pg.thrownotice(m)

if (length(chunks) == 0) {
  resdf <- data.frame(clean_pt_id = integer(0), class = character(0),
                      stringsAsFactors = FALSE)
} else {
  resdf <- do.call(rbind, chunks)
}

return(resdf)

$BODY$
LANGUAGE plr;

-- Four-argument convenience overload: scores every row of the table by
-- delegating to the five-argument form with an empty wherecondition.
CREATE OR REPLACE FUNCTION predict_classification_model
       (filename varchar(256), modelname varchar(256),
        schemaname varchar(256), tablename varchar(256) )
RETURNS SETOF predict_classification_model_type AS
$BODY$
    SELECT clean_pt_id, class
    FROM predict_classification_model($1, $2, $3, $4, '');
$BODY$
LANGUAGE sql;

-- Example usage (the four-argument form scores every row of the table):
-- SELECT predict_classification_model('/tmp/model.randomForest.mtry_4.ntree_1500.RData','themodel','fads_stats','all_stats');
