1 module hunt.http.codec.websocket.encode;
2 
3 import hunt.http.Exceptions;
4 import hunt.http.WebSocketFrame;
5 import hunt.http.codec.websocket.model.CloseInfo;
6 import hunt.http.codec.websocket.model.Extension;
7 import hunt.http.WebSocketCommon;
8 import hunt.http.WebSocketPolicy;
9 
10 import hunt.collection;
11 import hunt.Exceptions;
12 import hunt.text.Common;
13 import hunt.util.StringBuilder;
14 
15 /**
16  * Generating a frame in WebSocket land.
17  * <p>
18  * <pre>
19  *    0                   1                   2                   3
20  *    0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
21  *   +-+-+-+-+-------+-+-------------+-------------------------------+
22  *   |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
23  *   |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
24  *   |N|V|V|V|       |S|             |   (if payload len==126/127)   |
25  *   | |1|2|3|       |K|             |                               |
26  *   +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
27  *   |     Extended payload length continued, if payload len == 127  |
28  *   + - - - - - - - - - - - - - - - +-------------------------------+
29  *   |                               |Masking-key, if MASK set to 1  |
30  *   +-------------------------------+-------------------------------+
31  *   | Masking-key (continued)       |          Payload Data         |
32  *   +-------------------------------- - - - - - - - - - - - - - - - +
33  *   :                     Payload Data continued ...                :
34  *   + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
35  *   |                     Payload Data continued ...                |
36  *   +---------------------------------------------------------------+
37  * </pre>
38  */
class Generator {
    /**
     * Size of the scratch buffer allocated for a frame header.
     * The worst case is 14 bytes: 2 base bytes, 8 extended-length bytes and a 4-byte masking key.
     * The value leaves headroom above that.
     */
    enum int MAX_HEADER_LENGTH = 28;

    private WebSocketBehavior behavior;
    private bool validating;
    private bool readOnly;

    /**
     * Which RSV bits are allowed by the negotiated extensions.
     * <p>
     * <pre>
     *   0100_0000 (0x40) = rsv1
     *   0010_0000 (0x20) = rsv2
     *   0001_0000 (0x10) = rsv3
     * </pre>
     */
    private byte flagsInUse = 0x00;

    /**
     * Construct a validating, non-read-only Generator.
     *
     * @param policy the policy to use (only its behavior is retained)
     */
    this(WebSocketPolicy policy) {
        this(policy, true, false);
    }

    /**
     * Construct a non-read-only Generator.
     *
     * @param policy     the policy to use (only its behavior is retained)
     * @param validating true to enable RFC frame validation
     */
    this(WebSocketPolicy policy, bool validating) {
        this(policy, validating, false);
    }

    /**
     * Construct a Generator.
     *
     * @param policy     the policy to use (only its behavior is retained)
     * @param validating true to enable RFC frame validation
     * @param readOnly   true if generator is to treat frames as read-only and not modify them. Useful for debugging purposes, but not generally for runtime use.
     */
    this(WebSocketPolicy policy, bool validating, bool readOnly) {
        this.behavior = policy.getBehavior();
        this.validating = validating;
        this.readOnly = readOnly;
    }

    /**
     * Check a frame against the RFC 6455 framing rules. Does nothing unless validation is enabled.
     *
     * Throws: ProtocolException if an RSV bit is set without a negotiated extension using it,
     * or if a control frame is longer than 125 bytes or is not FIN.
     * A CLOSE payload is additionally checked by constructing a CloseInfo.
     */
    void assertFrameValid(WebSocketFrame frame) {
        if (!validating) {
            return;
        }

        /*
         * RFC 6455 Section 5.2
         * 
         * MUST be 0 unless an extension is negotiated that defines meanings for non-zero values. If a nonzero value is received and none of the negotiated
         * extensions defines the meaning of such a nonzero value, the receiving endpoint MUST _Fail the WebSocket Connection_.
         */
        if (frame.isRsv1() && !isRsv1InUse()) {
            throw new ProtocolException("RSV1 not allowed to be set");
        }

        if (frame.isRsv2() && !isRsv2InUse()) {
            throw new ProtocolException("RSV2 not allowed to be set");
        }

        if (frame.isRsv3() && !isRsv3InUse()) {
            throw new ProtocolException("RSV3 not allowed to be set");
        }

        if (OpCode.isControlFrame(frame.getOpCode())) {
            /*
             * RFC 6455 Section 5.5
             * 
             * All control frames MUST have a payload length of 125 bytes or less and MUST NOT be fragmented.
             */
            if (frame.getPayloadLength() > 125) {
                throw new ProtocolException("Invalid control frame payload length");
            }

            if (!frame.isFin()) {
                throw new ProtocolException("Control Frames must be FIN=true");
            }

            /*
             * RFC 6455 Section 5.5.1
             * 
             * close frame payload is specially formatted which is checked in CloseInfo
             */
            if (frame.getOpCode() == OpCode.CLOSE) {
                ByteBuffer payload = frame.getPayload();
                if (payload !is null) {
                    // Constructed only for its validation side effect; the instance is discarded.
                    new CloseInfo(payload, true);
                }
            }
        }
    }

    /**
     * Recompute which RSV bits may be set, from the given extensions. Resets all flags first.
     */
    void configureFromExtensions(Extension[] exts) {
        // default
        flagsInUse = 0x00;

        // configure from list of extensions in use
        foreach (Extension ext ; exts) {
            if (ext.isRsv1User()) {
                flagsInUse = cast(byte) (flagsInUse | 0x40);
            }
            if (ext.isRsv2User()) {
                flagsInUse = cast(byte) (flagsInUse | 0x20);
            }
            if (ext.isRsv3User()) {
                flagsInUse = cast(byte) (flagsInUse | 0x10);
            }
        }
    }

    /**
     * Generate the frame header into a freshly allocated buffer of MAX_HEADER_LENGTH bytes.
     * The buffer is returned in flush (read) mode.
     */
    ByteBuffer generateHeaderBytes(WebSocketFrame frame) {
        ByteBuffer buffer = BufferUtils.allocate(MAX_HEADER_LENGTH);
        generateHeaderBytes(frame, buffer);
        return buffer;
    }

    /**
     * Append the frame header to buffer.
     * <p>
     * When the frame is masked and the generator is not read-only, the frame payload is masked
     * in place as a side effect. In read-only mode the frame is left untouched.
     *
     * @param frame  the frame whose header is generated
     * @param buffer the destination; switched to fill mode while writing and back to flush mode afterwards
     * Throws: ProtocolException if validation fails or a masked frame does not carry a 4-byte key.
     *         Both checks happen before the buffer is touched.
     */
    void generateHeaderBytes(WebSocketFrame frame, ByteBuffer buffer) {
        // Validate before flipping the buffer so a rejected frame leaves the caller's buffer untouched.
        assertFrameValid(frame);

        bool masked = frame.isMasked();
        byte[] mask = null;
        if (masked) {
            mask = frame.getMask();
            // RFC 6455 Section 5.3: the masking key is exactly 32 bits. Any other size would produce
            // a malformed header and index past the end of the key during masking.
            if (mask.length != 4) {
                throw new ProtocolException("Masking key must be 4 bytes");
            }
        }

        int p = BufferUtils.flipToFill(buffer);

        // First byte: FIN, RSV1..RSV3 and the 4-bit opcode.
        byte b = 0x00;
        if (frame.isFin()) {
            b |= 0x80; // 1000_0000
        }
        if (frame.isRsv1()) {
            b |= 0x40; // 0100_0000
        }
        if (frame.isRsv2()) {
            b |= 0x20; // 0010_0000
        }
        if (frame.isRsv3()) {
            b |= 0x10; // 0001_0000
        }
        // NOTE: using .getOpCode() here, not .getType().getOpCode() for testing reasons
        b |= frame.getOpCode() & 0x0F;
        buffer.put(b);

        // Second byte: MASK bit plus the 7-bit payload length indicator.
        b = (masked ? cast(byte) 0x80 : cast(byte) 0x00);

        int payloadLength = frame.getPayloadLength();
        if (payloadLength > 0xFF_FF) {
            // Indicator 127: 64-bit extended length. payloadLength is an int, so the top 4 bytes are zero.
            b |= 0x7F;
            buffer.put(b);
            buffer.put(cast(byte) 0);
            buffer.put(cast(byte) 0);
            buffer.put(cast(byte) 0);
            buffer.put(cast(byte) 0);
            buffer.put(cast(byte) ((payloadLength >> 24) & 0xFF));
            buffer.put(cast(byte) ((payloadLength >> 16) & 0xFF));
            buffer.put(cast(byte) ((payloadLength >> 8) & 0xFF));
            buffer.put(cast(byte) (payloadLength & 0xFF));
        } else if (payloadLength >= 0x7E) {
            // Indicator 126: 16-bit extended length (126..65535).
            b |= 0x7E;
            buffer.put(b);
            buffer.put(cast(byte) (payloadLength >> 8));
            buffer.put(cast(byte) (payloadLength & 0xFF));
        } else {
            // 0..125 fits directly in the 7-bit field.
            b |= (payloadLength & 0x7F);
            buffer.put(b);
        }

        if (masked) {
            // MASK=1 must always be followed by the key, even in read-only mode.
            // Omitting it would corrupt the frame boundary.
            buffer.put(mask);
            if (!readOnly) {
                maskPayloadInPlace(frame.getPayload(), mask);
            }
        }

        BufferUtils.flipToFlush(buffer, p);
    }

    /**
     * Generate the whole frame (header + payload copy) into a single ByteBuffer.
     * <p>
     * Note: This is slow, moves lots of memory around. Only use this if you must (such as in unit testing).
     * In read-only mode a masked frame's payload is copied with the mask applied, leaving the frame unmodified.
     *
     * @param frame the frame to generate
     * @param buf   the buffer to output the generated frame to
     */
    void generateWholeFrame(WebSocketFrame frame, ByteBuffer buf) {
        buf.put(generateHeaderBytes(frame));
        if (!frame.hasPayload()) {
            return;
        }
        if (!readOnly) {
            // generateHeaderBytes has already masked the payload in place (if the frame is masked).
            buf.put(frame.getPayload());
        } else if (frame.isMasked()) {
            // The header declares MASK=1, so the emitted payload must be masked too.
            // The key length was already checked by generateHeaderBytes.
            putMaskedCopy(buf, frame.getPayload(), frame.getMask());
        } else {
            buf.put(frame.getPayload().slice());
        }
    }

    /**
     * XOR the payload bytes in [position, limit) with the 4-byte mask, without moving position.
     * NOTE(review): the 4-byte fast path assumes the payload buffer uses big-endian byte order.
     */
    private static void maskPayloadInPlace(ByteBuffer payload, byte[] mask) {
        if ((payload is null) || (payload.remaining() <= 0)) {
            return;
        }
        int maskInt = 0;
        foreach (byte maskByte ; mask) {
            maskInt = (maskInt << 8) + (maskByte & 0xFF);
        }
        int maskOffset = 0;
        int start = payload.position();
        int end = payload.limit();
        int remaining;
        while ((remaining = end - start) > 0) {
            if (remaining >= 4) {
                payload.putInt(start, payload.getInt(start) ^ maskInt);
                start += 4;
            } else {
                // Whole 4-byte chunks consume the key cyclically, so the tail restarts at mask[0].
                payload.put(start, cast(byte) (payload.get(start) ^ mask[maskOffset & 3]));
                ++start;
                ++maskOffset;
            }
        }
    }

    /**
     * Append a masked copy of payload[position, limit) to dst, leaving the payload unchanged.
     */
    private static void putMaskedCopy(ByteBuffer dst, ByteBuffer payload, byte[] mask) {
        int end = payload.limit();
        int maskOffset = 0;
        for (int i = payload.position(); i < end; ++i) {
            dst.put(cast(byte) (payload.get(i) ^ mask[maskOffset & 3]));
            ++maskOffset;
        }
    }

    /** Allow or disallow the RSV1 bit. Throws RuntimeException if the generator is read-only. */
    void setRsv1InUse(bool rsv1InUse) {
        if (readOnly) {
            throw new RuntimeException("Not allowed to modify read-only frame");
        }
        flagsInUse = cast(byte) ((flagsInUse & 0xBF) | (rsv1InUse ? 0x40 : 0x00));
    }

    /** Allow or disallow the RSV2 bit. Throws RuntimeException if the generator is read-only. */
    void setRsv2InUse(bool rsv2InUse) {
        if (readOnly) {
            throw new RuntimeException("Not allowed to modify read-only frame");
        }
        flagsInUse = cast(byte) ((flagsInUse & 0xDF) | (rsv2InUse ? 0x20 : 0x00));
    }

    /** Allow or disallow the RSV3 bit. Throws RuntimeException if the generator is read-only. */
    void setRsv3InUse(bool rsv3InUse) {
        if (readOnly) {
            throw new RuntimeException("Not allowed to modify read-only frame");
        }
        flagsInUse = cast(byte) ((flagsInUse & 0xEF) | (rsv3InUse ? 0x10 : 0x00));
    }

    /** Whether the RSV1 bit is allowed. */
    bool isRsv1InUse() {
        return (flagsInUse & 0x40) != 0;
    }

    /** Whether the RSV2 bit is allowed. */
    bool isRsv2InUse() {
        return (flagsInUse & 0x20) != 0;
    }

    /** Whether the RSV3 bit is allowed. */
    bool isRsv3InUse() {
        return (flagsInUse & 0x10) != 0;
    }

    override
    string toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("Generator[");
        builder.append(behavior);
        if (validating) {
            builder.append(",validating");
        }
        if (isRsv1InUse()) {
            builder.append(",+rsv1");
        }
        if (isRsv2InUse()) {
            builder.append(",+rsv2");
        }
        if (isRsv3InUse()) {
            builder.append(",+rsv3");
        }
        builder.append("]");
        return builder.toString();
    }
}