-module(batches_SUITE).
-compile([export_all, nowarn_export_all]).
-behaviour(gen_server).

-include_lib("stdlib/include/assert.hrl").
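%% Builds a per-testcase unique name by appending N to the calling function's name,
%% e.g. ?mod(1) inside internal_starts_another_cache yields internal_starts_another_cache1.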
-define(mod(N), list_to_atom(atom_to_list(?FUNCTION_NAME) ++ integer_to_list(N))).

all() ->
    [
     {group, cache},
     {group, async_workers}
    ].

groups() ->
    [
     {cache, [sequence],
      [
       internal_starts_another_cache,
       external_does_not_start_another_cache,
       internal_stop_does_stop_the_cache,
       external_stop_does_nothing,
       shared_cache_inserts_in_shared_table
      ]},
     {async_workers, [sequence],
      [
       broadcast_reaches_all_workers,
       broadcast_reaches_all_keys,
       filled_batch_raises_batch_metric,
       unfilled_batch_raises_flush_metric,
       timeouts_and_canceled_timers_do_not_need_to_log_messages,
       prepare_task_works,
       sync_flushes_down_everything,
       sync_aggregates_down_everything,
       aggregating_error_is_handled_and_can_continue,
       aggregation_might_produce_noop_requests,
       async_request,
       retry_request,
       retry_request_cancelled,
       retry_request_cancelled_in_verify_function,
       ignore_msg_when_waiting_for_reply,
       async_request_fails
      ]}
    ].

init_per_suite(Config) ->
    application:ensure_all_started(telemetry),
    meck:new(mongoose_metrics, [stub_all, no_link]),
    Config.

end_per_suite(_Config) ->
    meck:unload().

init_per_group(_, Config) ->
    Config.

end_per_group(_, _Config) ->
    ok.

init_per_testcase(_TestCase, Config) ->
    pg:start_link(mim_scope),
    mim_ct_sup:start_link(ejabberd_sup),
    meck:new(gen_mod, [passthrough]),
    Config.

end_per_testcase(_TestCase, _Config) ->
    meck:unload(gen_mod),
    ok.

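%% cache_config() is the stock mod_cache_users config (its own internal cache),
%% cache_config(internal) spells that out explicitly, and cache_config(Module)
%% makes the module reuse the cache owned by Module instead of starting a new one.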
cache_config() ->
    config_parser_helper:default_mod_config(mod_cache_users).

cache_config(internal) ->
    Def = config_parser_helper:default_mod_config(mod_cache_users),
    Def#{module => internal};
cache_config(Module) ->
    #{module => Module}.

%% Tests
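%% The cache tests count the segmented_cache workers under ejabberd_sup to tell
%% whether a module started its own cache or attached to an existing one.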
internal_starts_another_cache(_) ->
    mongoose_user_cache:start_new_cache(host_type(), ?mod(1), cache_config()),
    mongoose_user_cache:start_new_cache(host_type(), ?mod(2), cache_config(internal)),
    L = [S || S = {_Name, _Pid, worker, [segmented_cache]} <- supervisor:which_children(ejabberd_sup)],
    ?assertEqual(2, length(L)).

external_does_not_start_another_cache(_) ->
    mongoose_user_cache:start_new_cache(host_type(), ?mod(1), cache_config()),
    mongoose_user_cache:start_new_cache(host_type(), ?mod(2), cache_config(?mod(1))),
    L = [S || S = {_Name, _Pid, worker, [segmented_cache]} <- supervisor:which_children(ejabberd_sup)],
    ?assertEqual(1, length(L)).

internal_stop_does_stop_the_cache(_) ->
    meck:expect(gen_mod, get_module_opt, fun(_, _, module, _) -> internal end),
    mongoose_user_cache:start_new_cache(host_type(), ?mod(1), cache_config()),
    mongoose_user_cache:start_new_cache(host_type(), ?mod(2), cache_config(internal)),
    L1 = [S || S = {_Name, _Pid, worker, [segmented_cache]} <- supervisor:which_children(ejabberd_sup)],
    ct:pal("Cache children before stop: ~p~n", [L1]),
    mongoose_user_cache:stop_cache(host_type(), ?mod(2)),
    L2 = [S || S = {_Name, _Pid, worker, [segmented_cache]} <- supervisor:which_children(ejabberd_sup)],
    ct:pal("Cache children after stop: ~p~n", [L2]),
    ?assertNotEqual(L1, L2).

external_stop_does_nothing(_) ->
    meck:expect(gen_mod, get_module_opt, fun(_, _, module, _) -> ?mod(1) end),
    mongoose_user_cache:start_new_cache(host_type(), ?mod(1), cache_config()),
    mongoose_user_cache:start_new_cache(host_type(), ?mod(2), cache_config(?mod(1))),
    L1 = [S || S = {_Name, _Pid, worker, [segmented_cache]} <- supervisor:which_children(ejabberd_sup)],
    mongoose_user_cache:stop_cache(host_type(), ?mod(2)),
    L2 = [S || S = {_Name, _Pid, worker, [segmented_cache]} <- supervisor:which_children(ejabberd_sup)],
    ?assertEqual(L1, L2).

shared_cache_inserts_in_shared_table(_) ->
    meck:expect(gen_mod, get_module_opt, fun(_, _, module, _) -> ?mod(1) end),
    mongoose_user_cache:start_new_cache(host_type(), ?mod(1), cache_config()),
    mongoose_user_cache:start_new_cache(host_type(), ?mod(2), cache_config(?mod(1))),
    mongoose_user_cache:merge_entry(host_type(), ?mod(2), some_jid(), #{}),
    ?assert(mongoose_user_cache:is_member(host_type(), ?mod(1), some_jid())).

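%% The request callback forwards only the very first value (1) and drops every
%% aggregated batch, so all later flushes are no-ops and the accumulator ends at 1.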
aggregation_might_produce_noop_requests(_) ->
    {ok, Server} = gen_server:start_link(?MODULE, [], []),
    Requestor = fun(1, _) -> timer:sleep(1), gen_server:send_request(Server, 1);
                   (_, _) -> drop end,
    Opts = (default_aggregator_opts(Server))#{pool_id => ?FUNCTION_NAME,
                                              request_callback => Requestor},
    {ok, Pid} = gen_server:start_link(mongoose_aggregator_worker, Opts, []),
    [ gen_server:cast(Pid, {task, key, N}) || N <- lists:seq(1, 1000) ],
    async_helper:wait_until(
      fun() -> gen_server:call(Server, get_acc) end, 1).

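%% A broadcast task reaches every worker in the pool, so with pool_size 10 the
%% accumulator should collect ten 1s and reach 10.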
broadcast_reaches_all_workers(_) ->
    {ok, Server} = gen_server:start_link(?MODULE, [], []),
    WPoolOpts = (default_aggregator_opts(Server))#{pool_type => aggregate,
                                                   pool_size => 10},
    {ok, _} = mongoose_async_pools:start_pool(host_type(), ?FUNCTION_NAME, WPoolOpts),
    mongoose_async_pools:broadcast_task(host_type(), ?FUNCTION_NAME, key, 1),
    async_helper:wait_until(
      fun() -> gen_server:call(Server, get_acc) end, 10).

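%% Every key 0..1000 holds a task of 1; the broadcast of -1 must reach each key and
%% cancel it out, while the ETS flag parks the in-flight requests until the broadcast
%% has been sent, so the accumulator should end up at exactly 0.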
broadcast_reaches_all_keys(_) ->
    HostType = host_type(),
    {ok, Server} = gen_server:start_link(?MODULE, [], []),
    Tid = ets:new(table, [public, {read_concurrency, true}]),
    Req = fun(Task, _) ->
                  case ets:member(Tid, continue) of
                      true ->
                          gen_server:send_request(Server, Task);
                      false ->
                          async_helper:wait_until(fun() -> ets:member(Tid, continue) end, true),
                          gen_server:send_request(Server, 0)
                  end
          end,
    WPoolOpts = (default_aggregator_opts(Server))#{pool_type => aggregate,
                                                   pool_size => 3,
                                                   request_callback => Req},
    {ok, _} = mongoose_async_pools:start_pool(HostType, ?FUNCTION_NAME, WPoolOpts),
    [ mongoose_async_pools:put_task(HostType, ?FUNCTION_NAME, N, 1) || N <- lists:seq(0, 1000) ],
    mongoose_async_pools:broadcast(HostType, ?FUNCTION_NAME, -1),
    ets:insert(Tid, {continue, true}),
    async_helper:wait_until(
      fun() -> gen_server:call(Server, get_acc) end, 0).

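%% With batch_size 1 a single task fills the batch immediately,
%% which should bump the batch_flushes metric.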
filled_batch_raises_batch_metric(_) ->
    Opts = #{host_type => host_type(),
             pool_id => ?FUNCTION_NAME,
             batch_size => 1,
             flush_interval => 1000,
             flush_callback => fun(_, _) -> ok end,
             flush_extra => #{host_type => host_type(), queue_length => 0}},
    {ok, Pid} = gen_server:start_link(mongoose_batch_worker, Opts, []),
    gen_server:cast(Pid, {task, key, ok}),
    MetricName = [mongoose_async_pools, '_', batch_flushes],
    async_helper:wait_until(
      fun() -> 0 < meck:num_calls(mongoose_metrics, update, ['_', MetricName, '_']) end, true).

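%% The batch never fills up (batch_size 1000), so the 5 ms flush_interval fires
%% instead and should bump the timed_flushes metric.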
unfilled_batch_raises_flush_metric(_) ->
    Opts = #{host_type => host_type(),
             pool_id => ?FUNCTION_NAME,
             batch_size => 1000,
             flush_interval => 5,
             flush_callback => fun(_, _) -> ok end,
             flush_extra => #{host_type => host_type(), queue_length => 0}},
    {ok, Pid} = gen_server:start_link(mongoose_batch_worker, Opts, []),
    gen_server:cast(Pid, {task, key, ok}),
    MetricName = [mongoose_async_pools, '_', timed_flushes],
    async_helper:wait_until(
      fun() -> 0 < meck:num_calls(mongoose_metrics, update, ['_', MetricName, '_']) end, true).

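%% Filling the batch cancels the pending flush timer; neither a late timeout message
%% nor the cancelled timer should make the worker log anything.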
timeouts_and_canceled_timers_do_not_need_to_log_messages(_) ->
    Timeout = 10,
    QueueSize = 2,
    meck:new(logger, [passthrough, unstick]),
    Opts = #{host_type => host_type(),
             pool_id => ?FUNCTION_NAME,
             batch_size => QueueSize,
             flush_interval => Timeout,
             flush_callback => fun(_, _) -> ok end,
             flush_extra => #{host_type => host_type(), queue_length => 0}},
    {ok, Pid} = gen_server:start_link(mongoose_batch_worker, Opts, []),
    [ gen_server:cast(Pid, {task, ok}) || _ <- lists:seq(1, QueueSize) ],
    ct:sleep(Timeout*2),
    ?assertEqual(0, meck:num_calls(logger, macro_log, '_')).

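%% prep_callback rejects the 0 task (which should be logged) and bumps every other
%% task by one, so the flushed batch should be [2, 3].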
prepare_task_works(_) ->
    Timeout = 1000,
    QueueSize = 2,
    T = self(),
    meck:new(logger, [passthrough, unstick]),
    Opts = #{host_type => host_type(),
             pool_id => ?FUNCTION_NAME,
             batch_size => QueueSize,
             flush_interval => Timeout,
             prep_callback => fun(0, _) -> {error, bad};
                                 (A, _) -> {ok, A + 1}
                              end,
             flush_callback => fun(Tasks, _) -> T ! {tasks, Tasks}, ok end,
             flush_extra => #{host_type => host_type(), queue_length => 0}},
    {ok, Pid} = gen_server:start_link(mongoose_batch_worker, Opts, []),
    [ gen_server:cast(Pid, {task, N}) || N <- lists:seq(0, QueueSize) ],
    receive
        {tasks, Tasks} ->
            ?assertEqual([ N + 1 || N <- lists:seq(1, QueueSize) ], Tasks)
    after
        Timeout*2 -> ct:fail(no_answer_received)
    end,
    ?assert(0 < meck:num_calls(logger, macro_log, '_')).

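%% sync on an empty worker is a no-op (skipped); once a task is queued, sync forces
%% a flush and the timed_flushes metric is updated.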
sync_flushes_down_everything(_) ->
    Opts = #{host_type => host_type(),
             pool_id => ?FUNCTION_NAME,
             batch_size => 5000,
             flush_interval => 5000,
             flush_callback => fun(_, _) -> ok end,
             flush_extra => #{host_type => host_type(), queue_length => 0}},
    {ok, Pid} = gen_server:start_link(mongoose_batch_worker, Opts, []),
    ?assertEqual(skipped, gen_server:call(Pid, sync)),
    gen_server:cast(Pid, {task, key, ok}),
    ?assertEqual(ok, gen_server:call(Pid, sync)),
    MetricName = [mongoose_async_pools, '_', timed_flushes],
    ?assert(0 < meck:num_calls(mongoose_metrics, update, ['_', MetricName, '_'])).

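%% sync must flush everything aggregated so far: the sum of 1..1000 is 500500.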
sync_aggregates_down_everything(_) ->
    {ok, Server} = gen_server:start_link(?MODULE, [], []),
    Opts = (default_aggregator_opts(Server))#{pool_id => ?FUNCTION_NAME},
    {ok, Pid} = gen_server:start_link(mongoose_aggregator_worker, Opts, []),
    ?assertEqual(skipped, gen_server:call(Pid, sync)),
    [ gen_server:cast(Pid, {task, key, N}) || N <- lists:seq(1, 1000) ],
    ?assertEqual(ok, gen_server:call(Pid, sync)),
    ?assertEqual(500500, gen_server:call(Server, get_acc)).

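%% A return_error task makes the server reply with an error; the worker has to absorb
%% it and keep flushing later tasks on its own, so the accumulator must grow beyond
%% 55 (the sum of the first 1..10 tasks).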
aggregating_error_is_handled_and_can_continue(_) ->
    {ok, Server} = gen_server:start_link(?MODULE, [], []),
    Requestor = fun(Task, _) -> timer:sleep(1), gen_server:send_request(Server, Task) end,
    Opts = (default_aggregator_opts(Server))#{pool_id => ?FUNCTION_NAME,
                                              request_callback => Requestor},
    {ok, Pid} = gen_server:start_link(mongoose_aggregator_worker, Opts, []),
    [ gen_server:cast(Pid, {task, key, N}) || N <- lists:seq(1, 10) ],
    gen_server:cast(Pid, {task, return_error, return_error}),
    ct:sleep(100),
    [ gen_server:cast(Pid, {task, key, N}) || N <- lists:seq(11, 100) ],
    %% We don't call sync here because sync forces a flush;
    %% we want to check that the worker flushes on its own
    ct:sleep(100),
    ?assert(55 < gen_server:call(Server, get_acc)).

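%% Happy path: all 1000 tasks are aggregated and delivered,
%% so the accumulator reaches the sum of 1..1000, i.e. 500500.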
async_request(_) ->
    {ok, Server} = gen_server:start_link(?MODULE, [], []),
    Opts = (default_aggregator_opts(Server))#{pool_id => ?FUNCTION_NAME},
    {ok, Pid} = gen_server:start_link(mongoose_aggregator_worker, Opts, []),
    [ gen_server:cast(Pid, {task, key, N}) || N <- lists:seq(1, 1000) ],
    async_helper:wait_until(
      fun() -> gen_server:call(Server, get_acc) end, 500500).

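%% do_retry_request fails the first attempt (retry 0) and succeeds on the first retry,
%% so each task produces exactly two task_called messages.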
retry_request(_) ->
    Opts = (retry_aggregator_opts())#{pool_id => retry_request},
    {ok, Pid} = gen_server:start_link(mongoose_aggregator_worker, Opts, []),
    gen_server:cast(Pid, {task, key, 1}),
    receive_task_called(0),
    receive_task_called(1),
    gen_server:cast(Pid, {task, key, 1}),
    receive_task_called(0),
    receive_task_called(1),
    ensure_no_tasks_to_receive().

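%% do_cancel_request fails every attempt; after the initial try plus 3 retries the
%% task is given up, and the next task starts again from retry 0.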
retry_request_cancelled(_) ->
    Opts = (retry_aggregator_opts())#{pool_id => retry_request_cancelled,
                                      request_callback => fun do_cancel_request/2},
    {ok, Pid} = gen_server:start_link(mongoose_aggregator_worker, Opts, []),
    gen_server:cast(Pid, {task, key, 1}),
    receive_task_called(0),
    %% 3 retries
    receive_task_called(1),
    receive_task_called(2),
    receive_task_called(3),
    ensure_no_tasks_to_receive(),
    %% Second task gets started
    gen_server:cast(Pid, {task, key, 2}),
    receive_task_called(0).

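%% Same as above, but the request itself succeeds and it is verify_callback that
%% rejects every reply, triggering the same retry-then-cancel path.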
retry_request_cancelled_in_verify_function(_) ->
    Opts = (retry_aggregator_opts())#{pool_id => retry_request_cancelled_in_verify_function,
                                      request_callback => fun do_request/2,
                                      verify_callback => fun validate_all_fails/3},
    {ok, Pid} = gen_server:start_link(mongoose_aggregator_worker, Opts, []),
    gen_server:cast(Pid, {task, key, 1}),
    receive_task_called(0),
    %% 3 retries
    receive_task_called(1),
    receive_task_called(2),
    receive_task_called(3),
    ensure_no_tasks_to_receive(),
    %% Second task gets started
    gen_server:cast(Pid, {task, key, 2}),
    receive_task_called(0).

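%% An unrelated message arriving while the worker waits for a reply must be ignored,
%% not treated as the reply; the next task still runs normally.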
ignore_msg_when_waiting_for_reply(_) ->
    Opts = (retry_aggregator_opts())#{pool_id => ignore_msg_when_waiting_for_reply,
                                      request_callback => fun do_request_but_ignore_other_messages/2},
    {ok, Pid} = gen_server:start_link(mongoose_aggregator_worker, Opts, []),
    gen_server:cast(Pid, {task, key, 1}),
    receive_task_called(0),
    %% Second task gets started
    gen_server:cast(Pid, {task, key, 2}),
    receive_task_called(0).

async_request_fails(_) ->
    %% Makes a request that crashes the target gen_server, but not the aggregator worker
    {ok, Server} = gen_server:start({local, async_req_fails_server}, ?MODULE, [], []),
    Ref = monitor(process, Server),
    Opts = (default_aggregator_opts(async_req_fails_server))#{pool_id => ?FUNCTION_NAME},
    {ok, Pid} = gen_server:start_link(mongoose_aggregator_worker, Opts, []),
    gen_server:cast(Pid, {task, key, {ack_and_die, self()}}),
    %% The task is acked and the server dies
    receive
        {'DOWN', R, process, _, _} when R =:= Ref -> ok
    after 5000 -> error(down_receive_timeout)
    end,
    receive
        {acked, S} when S =:= Server -> ok
    after 5000 -> error(acked_receive_timeout)
    end,
    %% Eventually the task is cancelled
    async_helper:wait_until(fun() -> element(4, sys:get_state(Pid)) end, no_request_pending),
    %% Check that the aggregator still processes new tasks:
    %% start a new task and wait for it to be processed
    {ok, Server2} = gen_server:start({local, async_req_fails_server}, ?MODULE, [], []),
    gen_server:cast(Pid, {task, key, {ack, self()}}),
    receive
        {acked, S2} when S2 =:= Server2 -> ok
    after 5000 -> error(acked_receive_timeout)
    end,
    %% Check the accumulator state
    ?assertEqual(1, gen_server:call(Server2, get_acc)).

%% Helpers
host_type() ->
    <<"HostType">>.

some_jid() ->
    jid:make_noprep(<<"alice">>, <<"localhost">>, <<>>).

default_aggregator_opts(Server) ->
    #{host_type => host_type(),
      request_callback => requester(Server),
      aggregate_callback => fun aggregate_sum/3,
      verify_callback => fun validate_all_ok/3,
      flush_extra => #{host_type => host_type()}}.

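%% Opts for the retry tests: origin_pid lets the request callbacks report every
%% attempt (and its retry number) back to the test process.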
retry_aggregator_opts() ->
    #{host_type => host_type(),
      request_callback => fun do_retry_request/2,
      aggregate_callback => fun aggregate_sum/3,
      verify_callback => fun validate_ok/3,
      flush_extra => #{host_type => host_type(), origin_pid => self()}}.

validate_all_ok(ok, _, _) ->
    ok.

validate_ok(ok, _, _) ->
    ok;
validate_ok({error, Reason}, _, _) ->
    {error, Reason}.

validate_all_fails(_, _, _) ->
    {error, all_fails}.

aggregate_sum(T1, T2, _) ->
    {ok, T1 + T2}.

requester(Server) ->
    fun(return_error, _) ->
            gen_server:send_request(Server, return_error);
       ({ack_and_die, _} = Task, _) ->
            gen_server:send_request(Server, Task);
       ({ack, _} = Task, _) ->
            gen_server:send_request(Server, Task);
       (Task, _) ->
            timer:sleep(1), gen_server:send_request(Server, Task)
    end.

%% Fails the first attempt (retry 0) of each task, succeeds on the first retry
do_retry_request(_Task, #{origin_pid := Pid, retry_number := Retry}) ->
    Pid ! {task_called, Retry},
    Ref = make_ref(),
    Reply = case Retry of 0 -> {error, simulate_error}; 1 -> ok end,
    %% Simulate gen_server call reply
    self() ! {[alias|Ref], Reply},
    Ref.

do_request(_Task, #{origin_pid := Pid, retry_number := Retry}) ->
    Pid ! {task_called, Retry},
    Ref = make_ref(),
    Reply = ok,
    %% Simulate gen_server call reply
    self() ! {[alias|Ref], Reply},
    Ref.

%% Fails all tries
do_cancel_request(_Task, #{origin_pid := Pid, retry_number := Retry}) ->
    Pid ! {task_called, Retry},
    Ref = make_ref(),
    Reply = {error, simulate_error},
    %% Simulate gen_server call reply
    self() ! {[alias|Ref], Reply},
    Ref.

%% Succeeds, but first sends an unrelated message that the worker has to ignore
do_request_but_ignore_other_messages(_Task, #{origin_pid := Pid, retry_number := Retry}) ->
    Pid ! {task_called, Retry},
    Ref = make_ref(),
    Reply = ok,
    %% Send an unexpected message, which should be ignored
    self() ! unexpected_msg_should_be_ignored,
    %% Simulate gen_server call reply
    self() ! {[alias|Ref], Reply},
    Ref.

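%% gen_server callbacks: the suite itself acts as a minimal accumulator server that
%% the aggregator workers send their requests to; integer tasks are summed into the state.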
init([]) ->
    {ok, 0}.

handle_call(get_acc, _From, Acc) ->
    {reply, Acc, Acc};
handle_call(return_error, _From, Acc) ->
    {reply, {error, return_error}, Acc};
handle_call({ack_and_die, Pid}, _From, _Acc) ->
    Pid ! {acked, self()},
    error(oops);
handle_call({ack, Pid}, _From, Acc) ->
    Pid ! {acked, self()},
    {reply, ok, 1 + Acc};
handle_call(N, _From, Acc) ->
    {reply, ok, N + Acc}.

handle_cast(_Msg, Acc) ->
    {noreply, Acc}.

receive_task_called(ExpectedRetryNumber) ->
    receive
        {task_called, RetryNum} ->
            ExpectedRetryNumber = RetryNum
    after 5000 ->
        error(timeout)
    end.

ensure_no_tasks_to_receive() ->
    receive {task_called, _} -> error(unexpected_msg) after 0 -> ok end.