1: -module(mod_websockets_SUITE).
    2: -compile([export_all, nowarn_export_all]).
    3: -include_lib("eunit/include/eunit.hrl").
    4: -define(HANDSHAKE_TIMEOUT, 3000).
    5: -define(eq(E, I), ?assertEqual(E, I)).
    6: -define(PORT, 5280).
    7: -define(IP, {127, 0, 0, 1}).
    8: -define(FAST_PING_RATE, 500).
    9: -define(NEW_TIMEOUT, 1200).
%% The idle timeout is long enough for every ping-interval test that uses
%% ?NEW_TIMEOUT, since those tests wait for at most two pings.
%% The extra 300 ms is a safety margin.
   13: -define(IDLE_TIMEOUT, ?NEW_TIMEOUT * 2 + 300).
   14: 
   15: 
   16: all() ->
   17:     ping_tests() ++ subprotocol_header_tests() ++ timeout_tests().
   18: 
   19: ping_tests() ->
   20:     [ping_test,
   21:      set_ping_test,
   22:      disable_ping_test,
   23:      disable_and_set].
   24: 
   25: subprotocol_header_tests() ->
   26:     [agree_to_xmpp_subprotocol,
   27:      agree_to_xmpp_subprotocol_case_insensitive,
   28:      agree_to_xmpp_subprotocol_from_many,
   29:      do_not_agree_to_missing_subprotocol,
   30:      do_not_agree_to_other_subprotocol].
   31: 
   32: timeout_tests() ->
   33:     [connection_is_closed_after_idle_timeout,
   34:      client_ping_frame_resets_idle_timeout].
   35: 
   36: init_per_suite(C) ->
   37:     setup(),
   38:     C.
   39: 
   40: end_per_suite(_) ->
   41:     teardown(),
   42:     ok.
   43: 
   44: init_per_testcase(_, C) ->
   45:     C.
   46: 
   47: end_per_testcase(_, C) ->
   48:     C.
   49: 
   50: setup() ->
   51:     meck:unload(),
   52:     application:ensure_all_started(cowboy),
   53:     application:ensure_all_started(jid),
   54:     meck:new(supervisor, [unstick, passthrough, no_link]),
   55:     meck:new(gen_mod,[unstick, passthrough, no_link]),
   56:     %% Set ping rate
   57:     meck:expect(gen_mod,get_opt, fun(ping_rate, _, none) -> ?FAST_PING_RATE;
   58:                                     (A, B, C) -> meck:passthrough([A, B, C]) end),
   59:     meck:expect(supervisor, start_child,
   60:                 fun(ejabberd_listeners, {_, {_, start_link, [_]}, transient,
   61:                                          infinity, worker, [_]}) -> {ok, self()};
   62:                    (A,B) -> meck:passthrough([A,B])
   63:                 end),
   64:     %% Start websocket cowboy listening
   65: 
   66:     Opts = [{num_acceptors, 10},
   67:             {max_connections, 1024},
   68:             {modules, [{"_", "/http-bind", mod_bosh},
   69:                        {"_", "/ws-xmpp", mod_websockets,
   70:                         [{timeout, ?IDLE_TIMEOUT},
   71:                          {ping_rate, ?FAST_PING_RATE}]}]}],
   72:     ejabberd_cowboy:start_listener({?PORT, ?IP, tcp}, Opts).
   73: 
   74: 
   75: teardown() ->
   76:     meck:unload(),
   77:     cowboy:stop_listener(ejabberd_cowboy:ref({?PORT, ?IP, tcp})),
   78:     application:stop(cowboy),
   79:     %% Do not stop jid, Erlang 21 does not like to reload nifs
   80:     ok.
   81: 
   82: ping_test(_Config) ->
   83:     timer:sleep(500),
   84:     #{socket := Socket1} = ws_handshake(),
   85:     %% When
   86:     Resp = wait_for_ping(Socket1, 0, 5000),
   87:     %% then
   88:     ?eq(Resp, ok).
   89: 
   90: set_ping_test(_Config) ->
   91:     #{socket := Socket1, internal_socket := InternalSocket} = ws_handshake(),
   92:     %% When
   93:     mod_websockets:set_ping(InternalSocket, ?NEW_TIMEOUT),
   94:     WaitMargin = 300,
   95:     ok = wait_for_ping(Socket1, 0 , ?NEW_TIMEOUT + WaitMargin),
   96:     %% Im waiting too less time!
   97:     TooShort = 700,
   98:     ErrorTimeout = wait_for_ping(Socket1, 0, TooShort),
   99:     ?eq({error, timeout}, ErrorTimeout),
  100:     %% now I'm wait the remaining time (and some margin)
  101:     ok = wait_for_ping(Socket1, 0, ?NEW_TIMEOUT - TooShort + WaitMargin).
  102: 
  103: disable_ping_test(_Config) ->
  104:     #{socket := Socket1, internal_socket := InternalSocket} = ws_handshake(),
  105:     %% When
  106:     mod_websockets:disable_ping(InternalSocket),
  107:     %% Should not receive any packets
  108:     ErrorTimeout = wait_for_ping(Socket1, 0, ?FAST_PING_RATE),
  109:     %% then
  110:     ?eq(ErrorTimeout, {error, timeout}).
  111: 
  112: disable_and_set(_Config) ->
  113:     #{socket := Socket1, internal_socket := InternalSocket} = ws_handshake(),
  114:     %% When
  115:     mod_websockets:disable_ping(InternalSocket),
  116:     %% Should not receive any packets
  117:     ErrorTimeout = wait_for_ping(Socket1, 0, ?FAST_PING_RATE),
  118:     mod_websockets:set_ping(InternalSocket, ?NEW_TIMEOUT),
  119:     Resp1 = wait_for_ping(Socket1, 0, ?NEW_TIMEOUT + 100),
  120:     %% then
  121:     ?eq(ErrorTimeout, {error, timeout}),
  122:     ?eq(Resp1, ok).
  123: 
  124: connection_is_closed_after_idle_timeout(_Config) ->
  125:     #{socket := Socket} = ws_handshake(),
  126:     inet:setopts(Socket, [{active, true}]),
  127:     Closed = wait_for_close(Socket),
  128:     ?eq(Closed, ok).
  129: 
  130: client_ping_frame_resets_idle_timeout(_Config) ->
  131:     #{socket := Socket} = ws_handshake(#{extra_headers => [<<"sec-websocket-protocol: xmpp\r\n">>]}),
  132:     Now = os:system_time(millisecond),
  133:     inet:setopts(Socket, [{active, true}]),
  134:     WaitBeforePingFrame = (?IDLE_TIMEOUT) div 2,
  135:     timer:sleep(WaitBeforePingFrame),
  136:     %%Masked ping frame
  137:     Ping = << 1:1, 0:3, 9:4, 1:1, 0:39 >>,
  138:     ok = gen_tcp:send(Socket, Ping),
  139:     Closed = wait_for_close(Socket),
  140:     ?eq(Closed, ok),
  141:     End = os:system_time(millisecond),
  142:     %%Below we check if the time difference between now and the moment
  143:     %%the WebSocket was established is bigger then the the ?IDLE_TIMEOUT plus initial wait time
  144:     %%This shows that the connection was not killed after the first ?IDLE_TIMEOUT
  145:     ?assert(End - Now > ?IDLE_TIMEOUT + WaitBeforePingFrame).
  146: 
  147: wait_for_close(Socket) ->
  148:     receive
  149:         {tcp_closed, Socket} ->
  150:             ok
  151:     after ?IDLE_TIMEOUT + 500 ->
  152:               timeout
  153:     end.
  154: 
  155: 
  156: ws_handshake() ->
  157:     ws_handshake(#{}).
  158: 
  159: %% Client side
  160: %% Gun is too high level for subprotocol_header_tests checks
  161: ws_handshake(Opts) ->
  162:     Host = "localhost",
  163:     Port = ?PORT,
  164:     {ok, Socket} = gen_tcp:connect(Host, Port, [binary, {packet, raw},
  165:                                                 {active, false}]),
  166:     ok = gen_tcp:send(Socket,
  167:                       ["GET /ws-xmpp HTTP/1.1\r\n"
  168:                        "Host: localhost:5280\r\n"
  169:                        "Connection: upgrade\r\n"
  170:                        "Origin: http://localhost\r\n"
  171:                        "Sec-WebSocket-Key: NT1P6NvEFQyDDKuTyEN+1Q==\r\n"
  172:                        "Sec-WebSocket-Version: 13\r\n"
  173:                        "Upgrade: websocket\r\n"
  174:                        ++ maps:get(extra_headers, Opts, ""),
  175:                        "\r\n"]),
  176:     {ok, Handshake} = gen_tcp:recv(Socket, 0, 5000),
  177:     Packet = erlang:decode_packet(http, Handshake, []),
  178:     {ok, {http_response, {1,1}, 101, "Switching Protocols"}, Rest} = Packet,
  179:     {Headers, _} = consume_headers(Rest, []),
  180:     InternalSocket = get_websocket(),
  181:     #{socket => Socket, internal_socket => InternalSocket, headers => Headers}.
  182: 
  183: consume_headers(Data, Headers) ->
  184:     case erlang:decode_packet(httph, Data, []) of
  185:         {ok, http_eoh, Rest} ->
  186:             {lists:reverse(Headers), Rest};
  187:         {ok, {http_header,_,Name,_,Value}, Rest} ->
  188:             consume_headers(Rest, [{Name, Value}|Headers])
  189:     end.
  190: 
  191: wait_for_ping(_, Try, _) when Try > 10 ->
  192:     {error, no_ping_packet};
  193: wait_for_ping(Socket, Try, Timeout) ->
  194:     {Reply, Content} = gen_tcp:recv(Socket, 0, Timeout),
  195:     case Reply of
  196:         error ->
  197:             {error, Content};
  198:         ok ->
  199:             Ping = ws_rx_frame(<<"">>, 9),
  200:             case Content of
  201:                 Ping ->
  202:                     ok;
  203:                 _ ->
  204:                     wait_for_ping(Socket, Try + 1, Timeout)
  205:             end
  206:     end.
  207: 
  208: %% Helpers
  209: ws_rx_frame(Payload, Opcode) ->
  210:     Length = byte_size(Payload),
  211:     <<1:1, 0:3, Opcode:4, 0:1, Length:7, Payload/binary>>.
  212: 
  213: get_websocket() ->
  214:     %% Assumption: there's only one ranch protocol process running and
  215:     %% it's the one which started due to our gen_tcp:connect in ws_handshake/1
  216:     [{cowboy_clear, Pid}] = get_ranch_connections(),
  217:     %% This is a record! See mod_websockets: #websocket{}.
  218:     {websocket, Pid, fake_peername, undefined}.
  219: 
  220: get_child_by_mod(Sup, Mod) ->
  221:     Kids = supervisor:which_children(Sup),
  222:     case lists:keyfind([Mod], 4, Kids) of
  223:         false -> error(not_found, [Sup, Mod]);
  224:         {_, KidPid, _, _} -> KidPid
  225:     end.
  226: 
  227: get_ranch_connections() ->
  228:     LSup = get_child_by_mod(ranch_sup, ranch_listener_sup),
  229:     CSup = get_child_by_mod(LSup, ranch_conns_sup),
  230:     [ {Mod, Pid} || {_, Pid, _, [Mod]} <- supervisor:which_children(CSup) ].
  231: 
  232: wait_for_no_ranch_connections(Times) ->
  233:     case get_ranch_connections() of
  234:         [] ->
  235:             ok;
  236:         _ when Times > 0 ->
  237:             timer:sleep(100),
  238:             wait_for_no_ranch_connections(Times - 1);
  239:        Connections ->
  240:             error(#{reason => wait_for_no_ranch_connections_failed,
  241:                     connections => Connections})
  242:     end.
  243: 
  244: %% ---------------------------------------------------------------------
  245: %% subprotocol_header_tests test functions
  246: %% ---------------------------------------------------------------------
  247: 
  248: %% From RFC 6455:
  249: %%   The |Sec-WebSocket-Protocol| header field is used in the WebSocket
  250: %%   opening handshake.  It is sent from the client to the server and back
  251: %%   from the server to the client to confirm the subprotocol of the
  252: %%   connection.
  253: agree_to_xmpp_subprotocol(_) ->
  254:     check_subprotocol("Proper client behaviour", ["xmpp"], "xmpp").
  255: 
  256: agree_to_xmpp_subprotocol_case_insensitive(_) ->
  257:     %% The value must conform to the requirements
  258:     %% given in item 10 of Section 4.1 of this specification -- namely,
  259:     %% the value must be a token as defined by RFC 2616 [RFC2616].
  260:     %% ...
  261:     %% 10.  The request MAY include a header field with the name
  262:     %%      |Sec-WebSocket-Protocol|.  If present, this value indicates one
  263:     %%      or more comma-separated subprotocol the client wishes to speak,
  264:     %%      ordered by preference.  The elements that comprise this value
  265:     %%      MUST be non-empty strings with characters in the range U+0021 to
  266:     %%      U+007E not including separator characters as defined in
  267:     %%      [RFC2616] and MUST all be unique strings.  The ABNF for the
  268:     %%      value of this header field is 1#token, where the definitions of
  269:     %%      constructs and rules are as given in [RFC2616].
  270:     %% ...
  271:     check_subprotocol("Case insensitive", ["XMPP"], "XMPP").
  272: 
  273: %% From RFC 6455:
  274: %%   The |Sec-WebSocket-Protocol| header field MAY appear multiple times
  275: %%   in an HTTP request (which is logically the same as a single
  276: %%   |Sec-WebSocket-Protocol| header field that contains all values).
  277: %%   However, the |Sec-WebSocket-Protocol| header field MUST NOT appear
  278: %%   more than once in an HTTP response.
  279: agree_to_xmpp_subprotocol_from_many(_) ->
  280:     check_subprotocol("Two protocols in one header", ["xmpp, other"], "xmpp"),
  281:     check_subprotocol("Two protocols in one header", ["other, xmpp"], "xmpp"),
  282:     check_subprotocol("Two protocols in two headers", ["other", "xmpp"], "xmpp").
  283: 
  284: %% Do not set a Sec-Websocket-Protocol header in response if it's missing in a request.
  285: %%
  286: %% From RFC 6455:
  287: %%   if the server does not wish to agree to one of the suggested
  288: %%   subprotocols, it MUST NOT send back a |Sec-WebSocket-Protocol|
  289: %%   header field in its response.
  290: do_not_agree_to_missing_subprotocol(_) ->
  291:     check_subprotocol("Subprotocol header is missing", [], undefined).
  292: 
  293: %% Do not set a Sec-Websocket-Protocol header in response if it's provided, but not xmpp.
  294: do_not_agree_to_other_subprotocol(_) ->
  295:     check_subprotocol("Subprotocol is not xmpp", ["other"], undefined).
  296: 
  297: 
  298: %% ---------------------------------------------------------------------
  299: %% subprotocol_header_tests helper functions
  300: %% ---------------------------------------------------------------------
  301: 
  302: check_subprotocol(Comment, ProtoList, ExpectedProtocol) ->
  303:     ReqHeaders = lists:append(["Sec-Websocket-Protocol: " ++ Proto ++ "\r\n" || Proto <- ProtoList]),
  304:     Info = #{reason => check_subprotocol_failed,
  305:              comment => Comment,
  306:              expected_protocol => ExpectedProtocol,
  307:              request_headers => ReqHeaders},
  308:     #{headers := RespHeaders, socket := Socket} = ws_handshake(#{extra_headers => ReqHeaders}),
  309:     %% get_websocket/0 does not support more than one open connection
  310:     gen_tcp:close(Socket),
  311:     wait_for_no_ranch_connections(10),
  312:     RespProtocol = proplists:get_value("Sec-Websocket-Protocol", RespHeaders),
  313:     case RespProtocol of
  314:         ExpectedProtocol ->
  315:             ok;
  316:         _ ->
  317:             Info2 = Info#{response_headers => RespHeaders,
  318:                           response_protocol => RespProtocol},
  319:             ct:fail(Info2)
  320:     end.