
Extra primacy parameter


TODO: Clean up this notebook

Code
library(tidyverse)
library(targets)
tar_source()
tar_load(c(exp3_data_agg, exp3_data))

Define the new model likelihood, the deviance functions, and the estimation machinery:
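
Code
serial_recall_v2 <- function(
    setsize, ISI = rep(0.5, setsize), item_in_ltm = rep(TRUE, setsize),
    prop = 0.2, prop_ltm = 0.5, tau = 0.15, gain = 25, rate = 0.1, prim_prop = 1,
    r_max = 1, lambda = 1, growth = "linear") {
  # closed-form strengths for the three chunks, with the extra primacy
  # gradient prim_prop applied per serial position
  w1 <- 1 - (1 - prop_ltm) * item_in_ltm[1]
  psq <- prop^2
  tmp <- w1 * exp(-rate * ISI[1]) * psq
  tmp2 <- exp(-rate * ISI[2])
  s <- c(
    prop,
    (prop - tmp) * prim_prop,
    (prop - tmp * (1 - prop) * tmp2 - psq * tmp2) * prim_prop^2
  )

  p_recall <- 1 / (1 + exp(-gain * (s - tau)))
  p_recall
}

# negative log-likelihood of one condition's aggregated data
calcdev <- function(params, dat, exclude_sp1 = FALSE, version = 2, ...) {
  if (version == 1) {
    pred <- serial_recall(
      setsize = 3,
      ISI = dat$ISI,
      item_in_ltm = dat$item_in_ltm,
      prop = params["prop"],
      prop_ltm = params["prop_ltm"],
      tau = params["tau"],
      gain = params["gain"],
      rate = params["rate"],
      ...
    )
  } else {
    pred <- serial_recall_v2(
      setsize = 3,
      ISI = dat$ISI,
      item_in_ltm = dat$item_in_ltm,
      prop = params["prop"],
      prop_ltm = params["prop_ltm"],
      tau = params["tau"],
      gain = params["gain"],
      rate = params["rate"],
      prim_prop = params["prim_prop"],
      ...
    )
  }
  log_lik <- dbinom(dat$n_correct, dat$n_total, prob = pred, log = TRUE)
  if (exclude_sp1) {
    log_lik <- log_lik[-1]
  }
  -sum(log_lik)
}

# summed negative log-likelihood across conditions, with optional normal priors
overall_deviance <- function(params, split_data, ...,
                             priors = list()) {
  dev <- unlist(lapply(split_data, function(x) calcdev(params, x, ...)))
  out <- sum(dev)

  if (length(priors) > 0) {
    pnames <- names(priors)
    for (i in seq_along(priors)) {
      out <- out - dnorm(params[pnames[i]], mean = priors[[i]]$mean, sd = priors[[i]]$sd, log = TRUE)
    }
  }
  out
}

estimate_model <- function(start, data, two_step = FALSE, priors = list(), simplify = FALSE, by = c("chunk", "gap"), ...) {
  # internal helper functions: optimize on the unconstrained (logit) scale
  constrain_pars <- function(par) {
    par[c("prop", "prop_ltm", "tau", "rate", "prim_prop")] <- inv_logit(par[c("prop", "prop_ltm", "tau", "rate", "prim_prop")])
    par["gain"] <- inv_logit(par["gain"], lb = 0, ub = 100)
    par
  }

  unconstrain_pars <- function(par) {
    par[c("prop", "prop_ltm", "tau", "rate", "prim_prop")] <- logit(par[c("prop", "prop_ltm", "tau", "rate", "prim_prop")])
    par["gain"] <- logit(par["gain"], lb = 0, ub = 100)
    par
  }

  fn <- function(par, split_data, par2 = NULL, ...) {
    par <- c(par, par2)
    par <- constrain_pars(par)
    overall_deviance(par, split_data, ...)
  }

  start_uc <- unconstrain_pars(start)

  # if two_step is TRUE, nest the optimization (prop, prop_ltm, and rate) within the outer optimization (tau, gain)
  if (two_step) {
    start1 <- start_uc[c("prop", "prop_ltm", "rate", "prim_prop")]
    start2 <- start_uc[c("tau", "gain")]

    fn2 <- function(par, split_data, par2, ...) {
      fit <- optim(
        par = par2,
        fn = fn,
        split_data = split_data,
        control = list(maxit = 1e6),
        par2 = par,
        ...
      )
      environment(fn)$fit_inner <- fit
      fit$value
    }
  } else {
    start1 <- start_uc
    start2 <- NULL
    fn2 <- fn
  }

  groups <- interaction(data[, by])
  split_data <- split(data, groups)

  fit <- optim(
    par = start1,
    fn = fn2,
    split_data = split_data,
    control = list(maxit = 1e6),
    par2 = start2,
    priors = priors,
    ...
  )

  est <- fit$par
  convergence <- fit$convergence
  value <- fit$value
  counts <- fit$counts
  if (two_step) {
    est <- c(est, fit_inner$par)
    convergence <- convergence + fit_inner$convergence
    counts <- counts + fit_inner$counts
  }

  # return the estimated parameters
  est <- constrain_pars(est)
  class(est) <- "serial_recall_pars"
  fit <- structure(
    list(
      start = start,
      par = est,
      convergence = convergence,
      counts = counts,
      value = value
    ),
    class = "serial_recall_fit"
  )

  if (simplify) {
    out <- optimfit_to_df(fit)
    out$fit <- list(fit)
    return(out)
  }

  fit
}

# model predictions, optionally computed separately per group
predict.serial_recall_pars <- function(object, data, group_by, type = "response", ...) {
  if (missing(group_by)) {
    pred <- switch(type,
      "response" = serial_recall_v2(
        setsize = nrow(data),
        ISI = data$ISI,
        item_in_ltm = data$item_in_ltm,
        prop = object["prop"],
        prop_ltm = object["prop_ltm"],
        tau = object["tau"],
        gain = object["gain"],
        rate = object["rate"],
        prim_prop = object["prim_prop"],
        ...
      ),
      "strength" = serial_recall_strength(
        setsize = nrow(data),
        ISI = data$ISI,
        item_in_ltm = data$item_in_ltm,
        prop = object["prop"],
        prop_ltm = object["prop_ltm"],
        tau = object["tau"],
        gain = object["gain"],
        rate = object["rate"],
        ...
      )
    )
    return(pred)
  }

  by <- do.call(paste, c(data[, group_by], sep = "_"))
  out <- lapply(split(data, by), function(x) {
    x$pred_tmp_col295 <- predict(object, x, type = type, ...)
    x
  })
  out <- do.call(rbind, out)
  out <- suppressMessages(dplyr::left_join(data, out))
  out$pred_tmp_col295
}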

One example fit:

Code
set.seed(25)
fit <- estimate_model(c(start_fun(), prim_prop = rbeta(1, 5, 2)), exp3_data_agg, version = 2, exclude_sp1 = TRUE)
kableExtra::kable(optimfit_to_df(fit))
prop prop_ltm rate gain tau prim_prop deviance convergence
0.129707 0.5184667 0.1597853 42.70181 0.089936 0.9305835 760.8202 0
Code
exp3_data_agg$pred <- predict(fit, exp3_data_agg, group_by = c("chunk", "gap"))

exp3_data_agg |>
  ggplot(aes(x = gap, y = p_correct, color = chunk)) +
  geom_point(alpha = 0.5) +
  stat_smooth(method = "lm", se = FALSE, linewidth = 0.5) +
  geom_line(aes(y = pred), linetype = "dashed", linewidth = 1.1) +
  scale_color_discrete("First chunk LTM?") +
  facet_wrap(~itemtype)

Run the estimation 100 times with different starting values:

Code
set.seed(18)
fits <- run_or_load(
  {
    res <- replicate(
      n = 100,
      expr = estimate_model(
        c(start_fun(), prim_prop = rbeta(1, 5, 2)),
        exp3_data_agg,
        version = 2, exclude_sp1 = TRUE
      ),
      simplify = FALSE
    )
    res <- do.call(rbind, lapply(res, optimfit_to_df))
    arrange(res, deviance)
  },
  file = "output/exp3_fits_100_prop_primacy.rds"
)

Refine the fits with a second pass, restarting the optimizer from each first-pass solution:

Code
fits_refined <- run_or_load(
  {
    res <- apply(fits, 1, function(x) {
      fit <- estimate_model(x, exp3_data_agg, version = 2, exclude_sp1 = TRUE)
      optimfit_to_df(fit)
    })
    do.call(rbind, res)
  },
  file = "output/exp3_fits_100_prop_primacy_refined.rds"
)
# optimfit_to_df stores the summed negative log-likelihood;
# convert it to deviance (2 * NLL)
fits_refined$deviance <- 2 * fits_refined$deviance

Print fits sorted by deviance:

Code
kableExtra::kable(arrange(fits_refined, deviance))
prop prop_ltm rate gain tau prim_prop deviance convergence
0.0830018 0.5354440 0.1521677 100.00000 0.0664134 0.9592183 1520.717 0
0.0830345 0.5354821 0.1522194 99.92881 0.0664355 0.9592133 1520.717 0
0.0827561 0.5303713 0.1540288 99.99304 0.0661689 0.9585271 1520.720 0
0.0833021 0.5326726 0.1534809 98.99042 0.0665430 0.9585757 1520.723 0
0.0836369 0.5354833 0.1524269 98.53729 0.0668022 0.9588923 1520.724 0
0.0838078 0.5349846 0.1527897 98.13969 0.0669035 0.9587414 1520.726 0
0.0839589 0.5352356 0.1523192 97.77404 0.0669837 0.9586211 1520.728 0
0.0840183 0.5344278 0.1528843 97.58774 0.0670162 0.9585227 1520.729 0
0.0840788 0.5337896 0.1528134 97.48065 0.0670620 0.9585533 1520.729 0
0.0843323 0.5352119 0.1524255 96.99083 0.0672207 0.9584937 1520.731 0
0.0844863 0.5384292 0.1511989 96.96959 0.0673687 0.9588990 1520.732 0
0.0841521 0.5333316 0.1531871 97.06339 0.0670524 0.9581341 1520.733 0
0.0850234 0.5387903 0.1513496 95.74477 0.0676811 0.9585429 1520.738 0
0.0854952 0.5374167 0.1511926 94.65133 0.0679400 0.9581097 1520.744 0
0.0853425 0.5370373 0.1515787 94.77365 0.0678094 0.9579259 1520.744 0
0.0854451 0.5350333 0.1525429 94.50061 0.0678727 0.9577814 1520.745 0
0.0857269 0.5368538 0.1515001 94.18233 0.0680980 0.9581038 1520.746 0
0.0872424 0.5390619 0.1510209 91.40748 0.0690687 0.9577849 1520.763 0
0.0878679 0.5354037 0.1524971 89.87321 0.0693818 0.9569372 1520.771 0
0.0880266 0.5353644 0.1539354 89.55011 0.0694815 0.9567853 1520.778 0
0.0884297 0.5360268 0.1513224 88.72862 0.0696856 0.9565830 1520.779 0
0.0878443 0.5251040 0.1565616 89.02264 0.0691813 0.9553297 1520.786 0
0.0892337 0.5399017 0.1496270 87.49761 0.0702117 0.9566977 1520.788 0
0.0896874 0.5467911 0.1466649 86.80853 0.0704719 0.9567592 1520.802 0
0.0915173 0.5403781 0.1506760 83.44070 0.0715516 0.9554777 1520.814 0
0.0915863 0.5467830 0.1481750 83.97125 0.0717513 0.9567280 1520.815 0
0.0907884 0.5516159 0.1462604 85.63544 0.0713531 0.9578251 1520.816 0
0.0904884 0.5308324 0.1531446 84.52523 0.0708115 0.9548965 1520.819 0
0.0921162 0.5379638 0.1511176 82.15251 0.0718255 0.9546897 1520.824 0
0.0934299 0.5461201 0.1483120 80.77434 0.0727898 0.9556430 1520.837 0
0.0936599 0.5373178 0.1509119 79.91900 0.0728257 0.9546394 1520.847 0
0.0934048 0.5373752 0.1522977 79.64859 0.0724489 0.9532476 1520.856 0
0.0954108 0.5338564 0.1530252 76.62310 0.0736285 0.9523789 1520.874 0
0.0964956 0.5497604 0.1470154 76.30653 0.0746150 0.9547932 1520.879 0
0.0917041 0.5584003 0.1470930 83.85527 0.0717872 0.9566873 1520.882 0
0.0973273 0.5302367 0.1539601 73.54802 0.0746042 0.9507225 1520.911 0
0.0987585 0.5460551 0.1491697 72.64034 0.0757628 0.9528200 1520.913 0
0.0987959 0.5381743 0.1510308 71.89543 0.0755197 0.9509666 1520.921 0
0.1007429 0.5473097 0.1478690 70.29082 0.0769701 0.9526443 1520.941 0
0.1031380 0.5365476 0.1515734 66.21486 0.0778125 0.9484016 1520.993 0
0.1041868 0.5425277 0.1497028 65.23047 0.0784592 0.9486410 1521.005 0
0.1041107 0.5585138 0.1426450 66.85880 0.0790617 0.9531637 1521.005 0
0.1052323 0.5522870 0.1466097 65.03940 0.0794905 0.9512479 1521.012 0
0.1032888 0.5270420 0.1553102 65.66822 0.0777932 0.9473241 1521.014 0
0.1050829 0.5429193 0.1493026 64.45731 0.0790327 0.9487062 1521.016 0
0.1055747 0.5453514 0.1495319 64.38901 0.0795725 0.9500933 1521.017 0
0.1057176 0.5363009 0.1519290 63.21499 0.0791567 0.9469873 1521.037 0
0.1056761 0.5328284 0.1528260 63.19574 0.0791423 0.9469534 1521.040 0
0.1095031 0.5597043 0.1427385 60.18882 0.0814623 0.9482292 1521.102 0
0.1038245 0.5178414 0.1606086 64.11113 0.0776493 0.9440227 1521.108 0
0.1112165 0.5405131 0.1513658 57.83141 0.0821252 0.9451479 1521.121 0
0.1139487 0.5665163 0.1408282 56.79748 0.0842963 0.9495305 1521.163 0
0.1150464 0.5620415 0.1424970 55.49423 0.0846745 0.9478959 1521.174 0
0.1104956 0.5844357 0.1330230 61.23827 0.0830523 0.9550584 1521.176 0
0.1150277 0.5707915 0.1394673 56.13136 0.0850230 0.9502024 1521.195 0
0.1168787 0.5563987 0.1475113 53.57138 0.0854354 0.9456710 1521.224 0
0.1170302 0.5467398 0.1477255 52.79237 0.0850381 0.9430341 1521.225 0
0.1190151 0.5745878 0.1375233 52.76130 0.0870007 0.9487474 1521.265 0
0.1180843 0.5374202 0.1543369 51.51058 0.0853142 0.9404624 1521.282 0
0.1206378 0.5369294 0.1526183 49.75313 0.0867212 0.9404126 1521.311 0
0.1222813 0.5450775 0.1484505 48.58462 0.0874118 0.9398335 1521.338 0
0.1219608 0.5950076 0.1293583 51.47250 0.0890850 0.9520846 1521.387 0
0.1258652 0.5543236 0.1449665 46.47660 0.0893696 0.9402240 1521.393 0
0.1261358 0.5512622 0.1469249 46.15517 0.0893986 0.9393450 1521.405 0
0.1229203 0.6004439 0.1274795 51.09406 0.0898015 0.9532568 1521.437 0
0.1253455 0.5347543 0.1524909 45.93500 0.0884188 0.9358299 1521.453 0
0.1096435 0.6245279 0.1157842 65.14104 0.0838056 0.9644263 1521.465 0
0.1291902 0.5698808 0.1404231 44.52248 0.0909739 0.9401136 1521.507 0
0.1316542 0.5861278 0.1329112 44.07969 0.0930474 0.9448045 1521.528 0
0.1320240 0.5487744 0.1460666 42.47430 0.0920422 0.9368930 1521.552 0
0.1407647 0.5726243 0.1378135 38.67907 0.0966723 0.9386590 1521.727 0
0.1410802 0.5921934 0.1299198 39.08275 0.0972915 0.9421512 1521.746 0
0.1416289 0.5584553 0.1449145 37.55499 0.0961602 0.9331385 1521.772 0
0.1442493 0.5817889 0.1357689 37.16069 0.0981666 0.9379234 1521.805 0
0.1485820 0.5506684 0.1476578 34.44050 0.0989464 0.9297136 1521.974 0
0.1335032 0.4896621 0.1771724 40.02027 0.0912584 0.9248692 1522.027 0
0.1508329 0.5659537 0.1397505 33.56548 0.0994088 0.9283791 1522.043 0
0.1464991 0.6365233 0.1119736 38.25589 0.1015555 0.9525763 1522.066 0
0.1570901 0.5720074 0.1369226 31.39721 0.1020279 0.9277262 1522.192 0
0.1578781 0.5796311 0.1354878 31.28607 0.1025839 0.9289186 1522.195 0
0.1557528 0.5623220 0.1426612 31.36372 0.1005411 0.9230903 1522.254 0
0.1615952 0.5661500 0.1440328 29.91051 0.1040553 0.9264082 1522.330 0
0.1625162 0.5859924 0.1360934 29.80458 0.1043471 0.9275039 1522.355 0
0.1643793 0.6202930 0.1195451 30.34787 0.1072577 0.9396434 1522.388 0
0.1628606 0.6659198 0.1026569 32.71971 0.1100000 0.9576486 1522.606 0
0.1673065 0.5570623 0.1476521 27.42337 0.1035436 0.9131497 1522.795 0
0.1494665 0.7064391 0.0870326 40.84013 0.1073822 0.9772426 1522.822 0
0.1757971 0.5737036 0.1355153 25.60590 0.1074592 0.9172798 1522.859 0
0.1237810 0.7275594 0.0803559 60.88937 0.0958362 0.9892582 1522.916 0
0.1478485 0.7136234 0.0842819 42.33303 0.1072880 0.9809562 1522.918 0
0.1508772 0.7128888 0.0843160 40.69104 0.1086132 0.9801147 1522.946 0
0.1641892 0.5166694 0.1581206 27.52562 0.1009010 0.9058324 1522.977 0
0.1454702 0.7272316 0.0801023 44.59122 0.1069657 0.9859114 1523.054 0
0.1560018 0.7256169 0.0812223 39.04160 0.1118570 0.9840110 1523.160 0
0.1008613 0.7540632 0.0701410 95.94044 0.0832805 0.9999999 1523.350 0
0.1500814 0.7543605 0.0712736 44.53389 0.1115246 0.9976890 1523.548 0
0.2012512 0.6942274 0.0925222 22.99875 0.1245127 0.9570987 1523.700 0
0.1954065 0.5381045 0.1690204 21.11082 0.1135602 0.9050614 1524.333 0
0.2222245 0.7091319 0.0886436 19.79598 0.1324736 0.9621327 1524.418 0
0.2420613 0.7426193 0.0755768 17.80195 0.1411300 0.9779758 1525.069 0

Not as much variance as I thought there would be. Let's limit the set to fits within 2 deviance units of the best fit and look at the distribution of their parameter values:

Code
best_fits <- filter(fits_refined, deviance < min(deviance) + 2)

best_fits |>
  select(prop:prim_prop) |>
  pivot_longer(cols = everything(), names_to = "parameter", values_to = "value") |>
  ggplot(aes(x = value)) +
  geom_histogram(bins = 30) +
  facet_wrap(~parameter, scales = "free") +
  theme_pub()

Here is the plot of the best fitting model:

Code
best_fit <- estimate_model(unlist(best_fits[1, 1:6]), exp3_data_agg, version = 2, exclude_sp1 = TRUE)
exp3_data_agg$pred <- predict(best_fit, exp3_data_agg, group_by = c("chunk", "gap"))

exp3_data_agg |>
  ggplot(aes(x = gap, y = p_correct, color = chunk)) +
  geom_point(alpha = 0.5) +
  stat_smooth(method = "lm", se = FALSE, linewidth = 0.5) +
  geom_line(aes(y = pred), linetype = "dashed", linewidth = 1.1) +
  scale_color_discrete("First chunk LTM?") +
  facet_wrap(~itemtype)

and the parameter estimates:

Code
kableExtra::kable(t(as.data.frame(best_fit$par)))
prop prop_ltm rate gain tau prim_prop
x 0.0827601 0.5305627 0.1540501 99.99801 0.0661733 0.958538

With this rate, the resource recovery over time looks like this:

Code
resources <- function(R, rate, time, r_max = 1) {
  # amount of resource recovered after `time` seconds of asymptotic
  # growth, starting from the current level R
  (r_max - R) * (1 - exp(-rate * time))
}

time <- seq(0, 20, 0.1)
tibble(time = time, resources = resources(0, best_fits[1, ]$rate, time)) |>
  ggplot(aes(x = time, y = resources)) +
  geom_line() +
  theme_pub()

About 50% of the resource is recovered after 4.5 seconds, and 25% after 1.87 seconds.
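
These times follow from inverting the asymptotic recovery curve 1 - exp(-rate * t), which gives t = -log(1 - p) / rate for a recovered proportion p. A quick check with the best-fit rate:

Code
rate <- best_fit$par["rate"]  # ~0.154
-log(1 - 0.50) / rate         # time to 50% recovery, ~4.5 s
-log(1 - 0.25) / rate         # time to 25% recovery, ~1.87 s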

Simulate full trial sequences

With recovery during entire recall period

Calculate the total recall period duration on each trial:

Code
total_rts <- exp3_data |>
  group_by(id, trial) |>
  # summed response times plus the fixed empty-screen time
  # during recall (9 * 0.2 s + 1 s)
  summarize(total_recall_duration = sum(rt) + 9 * 0.2 + 1)

hist(total_rts$total_recall_duration, breaks = 30, col = "grey", border = "white", xlab = "Total recall period duration (s.)", main = "")

Create the trial structure for the simulation:

Code
exp3_trial_str <- run_or_load({
  exp3_data |>
  group_by(id, trial) |>
  do({
    aggregate_data(.)
  })
}, file = "output/exp3_trial_structure.rds")

exp3_trial_str <- left_join(exp3_trial_str, total_rts) |> 
  mutate(ISI = ifelse(itemtype == "SP7-9", total_recall_duration, ISI),
         ser_pos = case_when(
          itemtype == "SP1-3" ~ 1,
          itemtype == "SP4-6" ~ 2,
          itemtype == "SP7-9" ~ 3
         ))

Run the model without resetting the resource at the start of each trial:

Code
serial_recall_full <- function(
    setsize, ISI = rep(0.5, setsize), item_in_ltm = rep(TRUE, setsize), ser_pos = 1:setsize, 
    prop = 0.2, prop_ltm = 0.5, tau = 0.15, gain = 25, rate = 0.1, prim_prop = 1,
    r_max = 1, lambda = 1, growth = "linear") {
  R <- r_max
  p_recall <- vector("numeric", length = setsize)
  prop_ltm <- ifelse(item_in_ltm, prop_ltm, 1)

  for (item in 1:setsize) {
    # strength of the item and recall probability
    strength <- (prop * R)^lambda * prim_prop^(ser_pos[item] - 1)
    p_recall[item] <- 1 / (1 + exp(-(strength - tau) * gain))

    # amount of resources consumed by the item
    r_cost <- strength^(1 / lambda) * prop_ltm[item]
    R <- R - r_cost

    # recover resources
    R <- switch(growth,
      "linear" = min(r_max, R + rate * ISI[item]),
      "asy" = R + (r_max - R) * (1 - exp(-rate * ISI[item]))
    )
  }

  p_recall
}

subj1 <- exp3_trial_str |>
  filter(id == 44125)

Here are the predictions for one example subject:

Code
subj1$pred <- serial_recall_full(
  setsize = nrow(subj1),
  ISI = subj1$ISI,
  item_in_ltm = subj1$item_in_ltm,
  ser_pos = subj1$ser_pos,
  prop = best_fit$par["prop"],
  prop_ltm = best_fit$par["prop_ltm"],
  tau = best_fit$par["tau"],
  gain = best_fit$par["gain"],
  rate = best_fit$par["rate"],
  prim_prop = best_fit$par["prim_prop"],
  growth = "asy"
)

subj1 |>
  ungroup() |>
  mutate(absolute_position = 1:n()) |>
  ggplot(aes(x = absolute_position, y = pred)) +
  geom_point() +
  geom_line() +
  theme_pub()

Same, but collapsed over serial position (one value per trial):

Code
subj1 |>
  ungroup() |>
  ggplot(aes(x = trial, y = pred)) +
  stat_summary(geom = "point") +
  geom_smooth() +
  theme_pub()

Now simulate the full dataset:

Code
sim_no_reset_full <- function(data, best_fit) {
  data |>
    group_by(id) |>
    mutate(pred = serial_recall_full(
      setsize = n(),
      ISI = ISI,
      item_in_ltm = item_in_ltm,
      ser_pos = ser_pos,
      prop = best_fit$par["prop"],
      prop_ltm = best_fit$par["prop_ltm"],
      tau = best_fit$par["tau"],
      gain = best_fit$par["gain"],
      rate = best_fit$par["rate"],
      prim_prop = best_fit$par["prim_prop"],
      growth = "asy"
    ))
}

full_sim <- sim_no_reset_full(exp3_trial_str, best_fit)

ggplot(full_sim, aes(x = trial, y = pred)) +
  stat_summary() +
  geom_smooth() +
  theme_pub()

Yes, the model predicts worsening performance over trials, but the effect is minuscule.
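
To put a number on "minuscule", here is one quick check (a sketch, pooling over subjects rather than fitting per subject): the linear trend of predicted accuracy across trials.

Code
# average change in predicted p(correct) per trial, pooled across subjects
trend <- lm(pred ~ trial, data = full_sim)
coef(trend)["trial"]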

Plot the original predictions recomputed with the full trial-by-trial model:

Code
full_sim |>
  group_by(chunk, gap, itemtype) |>
  summarize(pred = mean(pred),
            p_correct = mean(p_correct)) |>
  ggplot(aes(x = gap, y = p_correct, color = chunk)) +
  geom_point(alpha = 0.5) +
  stat_smooth(method = "lm", se = FALSE, linewidth = 0.5) +
  geom_line(aes(y = pred), linetype = "dashed", linewidth = 1.1) +
  scale_color_discrete("First chunk LTM?") +
  facet_wrap(~itemtype)

No recovery during recall, except during empty screens

Code
# reload the cached trial structure, then treat only the empty-screen
# time (9 * 0.2 s + 1 s) as a recovery opportunity during recall
exp3_trial_str <- run_or_load(
  {
    exp3_data |>
      group_by(id, trial) |>
      do({
        aggregate_data(.)
      })
  },
  file = "output/exp3_trial_structure.rds"
)


exp3_trial_str <- exp3_trial_str |>
  mutate(
    ISI = ifelse(itemtype == "SP7-9", 9 * 0.2 + 1, ISI),
    ser_pos = case_when(
      itemtype == "SP1-3" ~ 1,
      itemtype == "SP4-6" ~ 2,
      itemtype == "SP7-9" ~ 3
    )
  )

full_sim_v2 <- sim_no_reset_full(exp3_trial_str, best_fit)

ggplot(full_sim_v2, aes(x = trial, y = pred)) +
  stat_summary() +
  theme_pub()

Code
filter(full_sim_v2, id == 44125) |>
  mutate(absolute_position = 1:n()) |>
  ggplot(aes(x = absolute_position, y = pred)) +
  geom_point() +
  geom_line() +
  theme_pub()

Code
full_sim_v2 |>
  group_by(chunk, gap, itemtype) |>
  summarize(pred = mean(pred),
            p_correct = mean(p_correct)) |>
  ggplot(aes(x = gap, y = p_correct, color = chunk)) +
  geom_point(alpha = 0.5) +
  stat_smooth(method = "lm", se = FALSE, linewidth = 0.5) +
  geom_line(aes(y = pred), linetype = "dashed", linewidth = 1.1) +
  scale_color_discrete("First chunk LTM?") +
  facet_wrap(~itemtype)
