¿Es posible usar data.table para aplicar una función de dos parámetros rápidamente por grupo en un conjunto de datos? En un conjunto de datos de 1 millón de filas, descubro que llamar a la función simple definida a continuación lleva más de 11 segundos, que es mucho más tiempo de lo que esperaría para algo de esta complejidad.

El siguiente código autónomo describe los elementos esenciales de lo que estoy tratando de hacer:

# generate data frame - 1 million rows
library(data.table)
set.seed(42)
nn = 1e6
daf = data.frame(aa=sample(1:1000, nn, repl=TRUE),
                 bb=sample(1:1000, nn, repl=TRUE),
                 xx=rnorm(nn),
                 yy=rnorm(nn),
                 stringsAsFactors=FALSE)

# myfunc is the function to apply to each group
myfunc = function(xx, yy) {
  if (max(yy)>1) { 
    return(mean(xx))
  } else {
    return(weighted.mean(yy, ifelse(xx>0, 2, 1)))
  }
}

# running the function takes around 11.5 seconds
system.time({
  dt = data.table(daf, key=c("aa","bb"))
  dt = dt[,myfunc(xx, yy), by=c("aa","bb")]
})

head(dt)
# OUTPUT:
#    aa bb          V1
# 1:  1  2 -1.02605645
# 2:  1  3 -0.49318243
# 3:  1  4  0.02165797
# 4:  1  5  0.40811793
# 5:  1  6 -1.00312393
# 6:  1  7  0.14754417

¿Hay alguna manera de reducir significativamente el tiempo para una llamada de función como esta?

Estoy interesado en saber si hay una forma más eficiente de realizar el cálculo anterior sin reescribir completamente la llamada a la función, o si solo se puede acelerar dividiendo la función y reescribiéndola de alguna manera en la sintaxis data.table.

Muchas gracias de antemano por sus respuestas.

0
Timoji 20 oct. 2017 a las 13:41

3 respuestas

La mejor respuesta

He encontrado una manera de obtener una mayor velocidad de 8x, lo que reduce el tiempo a alrededor de 0.2 segundos en mi máquina. Vea abajo. En lugar de calcular la suma (yyw) / sum (w) directamente para cada grupo, lo que lleva mucho tiempo, calculamos las cantidades sum (yyw) y sum (w) para cada grupo, y solo después realizamos la división. ¡Magia!

system.time({
  dt <- data.table(daf, key = c("aa","bb"))
  dt[, w := 1][xx > 0, w := 2]
  dt[, yyw := yy * w]
  res <- dt[, .(maxy = max(yy), 
                meanx = mean(xx), 
                wm2num = sum(yyw), 
                wm2den = sum(w)),
                by = c("aa","bb")]
  res[, wm2 := wm2num/wm2den]            
  res[, V1 := wm2][maxy > 1, V1 := meanx]

  res[, c("maxy", "meanx", "wm2num", "wm2den", "wm2") := NULL]
}) # 0.19

all.equal(res, dtInitial)
# [1] TRUE
1
Timoji 23 oct. 2017 a las 17:01

Tus resultados:

system.time({
  dt = data.table(daf, key = c("aa","bb"))
  dt = dt[,myfunc(xx, yy), by = c("aa","bb")]
}) # 21.25
dtInitial <- copy(dt)

V1: si los valores de NA no le conciernen, puede modificar su función de esta manera:

myfunc2 = function(xx, yy) {
  if (max(yy) > 1) { 
    return(mean(xx))
  } else {
    w <- ifelse(xx > 0, 2, 1)
    return(sum((yy * w)[w != 0])/sum(w))
  }
}

system.time({
  dt = data.table(daf, key = c("aa","bb"))
  dtM = dt[, myfunc2(xx, yy), by = c("aa","bb")]
}) # 6.69  
all.equal(dtM, dtInitial)
# [1] TRUE

V2: Además, puedes hacerlo más rápido así:

system.time({
dt3 <- data.table(daf, key = c("aa","bb"))
dt3[, maxy := max(yy), by = c("aa","bb")]
dt3[, meanx := mean(xx), by = c("aa","bb")]
dt3[, w := ifelse(xx > 0, 2, 1)]
dt3[, wm2 := sum((yy * w)[w != 0])/sum(w), by = c("aa","bb")]
r2 <- dt3[, .(aa, bb, V1 = ifelse(maxy > 1, meanx, wm2))]
r2 <- unique(r2)
}) #2.09 
all.equal(r2, dtInitial)
# [1] TRUE

20 sek vs 2 sek para mí


Actualizar:

O un poco más rápida:

system.time({
  dt3 <- data.table(daf, key = c("aa","bb"))
  dt3[, w := ifelse(xx > 0, 2, 1)]
  dt3[, yyw := yy * w]
  r2 <- dt3[, .(maxy = max(yy),
                meanx = mean(xx),
                wm2 = sum(yyw)/sum(w)), 
            , by = c("aa","bb")]
  r2[, V1 := ifelse(maxy > 1, meanx, wm2)]
  r2[, c("maxy", "meanx", "wm2") := NULL]
}) # 1.51

all.equal(r2, dtInitial)
# [1] TRUE
2
minem 20 oct. 2017 a las 11:55

Otra solución

system.time({
  dat <- data.table(daf, key = c("aa","bb"))
  dat[, xweight := (xx > 0) * 1 + 1]
  result <- dat[, list(MaxY = max(yy), Mean1 = mean(xx), Mean2 = sum(yy*xweight)/sum(xweight)), keyby=c("aa", "bb")]
  result[, FinalMean := ifelse(MaxY > 1, Mean1, Mean2)]
})

   user  system elapsed 
  1.964   0.059   1.348
1
Ben 20 oct. 2017 a las 23:50