1 module hunt.http.codec.websocket.encode;
2 
3 import hunt.http.codec.websocket.exception;
4 import hunt.http.codec.websocket.frame.Frame;
5 import hunt.http.codec.websocket.model.CloseInfo;
6 import hunt.http.codec.websocket.model.Extension;
7 import hunt.http.codec.websocket.model.common;
8 import hunt.http.codec.websocket.stream.WebSocketPolicy;
9 
10 import hunt.container;
11 import hunt.lang.exception;
12 import hunt.string;
13 
/**
 * Generates WebSocket frame headers (and, mainly for tests, whole frames).
 * <p>
 * Frame layout (RFC 6455 Section 5.2):
 * <pre>
 *    0                   1                   2                   3
 *    0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
 *   +-+-+-+-+-------+-+-------------+-------------------------------+
 *   |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
 *   |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
 *   |N|V|V|V|       |S|             |   (if payload len==126/127)   |
 *   | |1|2|3|       |K|             |                               |
 *   +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
 *   |     Extended payload length continued, if payload len == 127  |
 *   + - - - - - - - - - - - - - - - +-------------------------------+
 *   |                               |Masking-key, if MASK set to 1  |
 *   +-------------------------------+-------------------------------+
 *   | Masking-key (continued)       |          Payload Data         |
 *   +-------------------------------- - - - - - - - - - - - - - - - +
 *   :                     Payload Data continued ...                :
 *   + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
 *   |                     Payload Data continued ...                |
 *   +---------------------------------------------------------------+
 * </pre>
 */
class Generator {
    /**
     * Size of the scratch buffer allocated by {@code generateHeaderBytes(Frame)}.
     * <p>
     * The largest header written below is 14 bytes: 1 byte FIN/RSV/opcode, 1 byte MASK/length,
     * 8 bytes extended length and a 4 byte masking key. This constant leaves headroom above that.
     */
    enum int MAX_HEADER_LENGTH = 28;

    private WebSocketBehavior behavior;
    private bool validating;
    private bool readOnly;

    /**
     * RSV bits that negotiated extensions are allowed to set.
     * <p>
     * <pre>
     *   0100_0000 (0x40) = rsv1
     *   0010_0000 (0x20) = rsv2
     *   0001_0000 (0x10) = rsv3
     * </pre>
     */
    private byte flagsInUse = 0x00;

    /**
     * Construct a validating, non-read-only Generator from the provided policy.
     *
     * @param policy the policy to use
     */
    this(WebSocketPolicy policy) {
        this(policy, true, false);
    }

    /**
     * Construct a non-read-only Generator from the provided policy.
     *
     * @param policy     the policy to use
     * @param validating true to enable RFC frame validation
     */
    this(WebSocketPolicy policy, bool validating) {
        this(policy, validating, false);
    }

    /**
     * Construct a Generator from the provided policy.
     *
     * @param policy     the policy to use (only its behavior is retained)
     * @param validating true to enable RFC frame validation
     * @param readOnly   true if generator is to treat frames as read-only and not modify them. Useful for debugging purposes, but not generally for runtime use.
     */
    this(WebSocketPolicy policy, bool validating, bool readOnly) {
        this.behavior = policy.getBehavior();
        this.validating = validating;
        this.readOnly = readOnly;
    }

    /**
     * Check the frame against RFC 6455 framing rules; a no-op when validation is disabled.
     *
     * @param frame the frame about to be generated
     * @throws ProtocolException if an RSV bit is set without a negotiated extension using it,
     *         or a control frame is longer than 125 bytes or not FIN
     */
    void assertFrameValid(Frame frame) {
        if (!validating) {
            return;
        }

        /*
         * RFC 6455 Section 5.2
         * 
         * MUST be 0 unless an extension is negotiated that defines meanings for non-zero values. If a nonzero value is received and none of the negotiated
         * extensions defines the meaning of such a nonzero value, the receiving endpoint MUST _Fail the WebSocket Connection_.
         */
        if (frame.isRsv1() && !isRsv1InUse()) {
            throw new ProtocolException("RSV1 not allowed to be set");
        }

        if (frame.isRsv2() && !isRsv2InUse()) {
            throw new ProtocolException("RSV2 not allowed to be set");
        }

        if (frame.isRsv3() && !isRsv3InUse()) {
            throw new ProtocolException("RSV3 not allowed to be set");
        }

        if (OpCode.isControlFrame(frame.getOpCode())) {
            /*
             * RFC 6455 Section 5.5
             * 
             * All control frames MUST have a payload length of 125 bytes or less and MUST NOT be fragmented.
             */
            if (frame.getPayloadLength() > 125) {
                throw new ProtocolException("Invalid control frame payload length");
            }

            if (!frame.isFin()) {
                throw new ProtocolException("Control Frames must be FIN=true");
            }

            /*
             * RFC 6455 Section 5.5.1
             * 
             * close frame payload is specially formatted which is checked in CloseInfo
             */
            if (frame.getOpCode() == OpCode.CLOSE) {

                ByteBuffer payload = frame.getPayload();
                if (payload !is null) {
                    // Constructed only for its checking side effect; the instance is discarded.
                    // NOTE(review): presumably the `true` argument enables validation — confirm in CloseInfo.
                    new CloseInfo(payload, true);
                }
            }
        }
    }

    /**
     * Recompute which RSV bits may be used, from the extensions currently in use.
     * Any previously configured flags are cleared first.
     *
     * @param exts the negotiated extensions
     */
    void configureFromExtensions(Extension[] exts) {
        // default
        flagsInUse = 0x00;

        // configure from list of extensions in use
        foreach (Extension ext ; exts) {
            if (ext.isRsv1User()) {
                flagsInUse = cast(byte) (flagsInUse | 0x40);
            }
            if (ext.isRsv2User()) {
                flagsInUse = cast(byte) (flagsInUse | 0x20);
            }
            if (ext.isRsv3User()) {
                flagsInUse = cast(byte) (flagsInUse | 0x10);
            }
        }
    }

    /**
     * Generate the frame header into a newly allocated buffer of MAX_HEADER_LENGTH bytes.
     *
     * @param frame the frame whose header is generated
     * @return the buffer holding the header
     */
    ByteBuffer generateHeaderBytes(Frame frame) {
        ByteBuffer buffer = BufferUtils.allocate(MAX_HEADER_LENGTH);
        generateHeaderBytes(frame, buffer);
        return buffer;
    }

    /**
     * Generate the frame header into {@code buffer}.
     * <p>
     * For a masked frame the masking key is appended to the header. Unless this generator is
     * read-only, the frame's payload is also masked in place.
     *
     * @param frame  the frame whose header is generated
     * @param buffer the destination buffer; it is switched to fill mode for writing and back
     *               to flush mode afterwards via BufferUtils
     * @throws ProtocolException if validation is enabled and the frame is invalid
     */
    void generateHeaderBytes(Frame frame, ByteBuffer buffer) {
        // Validate before touching the buffer so a failed check cannot leave the caller's
        // buffer flipped into fill mode.
        assertFrameValid(frame);

        int p = BufferUtils.flipToFill(buffer);

        // First byte: FIN, RSV1-3 and the 4-bit opcode.
        byte b = 0x00;
        if (frame.isFin()) {
            b |= 0x80; // 1000_0000
        }
        if (frame.isRsv1()) {
            b |= 0x40; // 0100_0000
        }
        if (frame.isRsv2()) {
            b |= 0x20; // 0010_0000
        }
        if (frame.isRsv3()) {
            b |= 0x10; // 0001_0000
        }

        // NOTE: using .getOpCode() here, not .getType().getOpCode() for testing reasons
        byte opcode = frame.getOpCode();
        b |= opcode & 0x0F;
        buffer.put(b);

        // Second byte: MASK bit plus the 7-bit payload length indicator.
        b = (frame.isMasked() ? cast(byte) 0x80 : cast(byte) 0x00);
        int payloadLength = frame.getPayloadLength();

        if (payloadLength > 0xFF_FF) {
            // Indicator 127: an 8-byte extended length follows. The length is an int here,
            // so the upper four bytes are always zero.
            b |= 0x7F;
            buffer.put(b);
            buffer.put(cast(byte) 0);
            buffer.put(cast(byte) 0);
            buffer.put(cast(byte) 0);
            buffer.put(cast(byte) 0);
            buffer.put(cast(byte) ((payloadLength >> 24) & 0xFF));
            buffer.put(cast(byte) ((payloadLength >> 16) & 0xFF));
            buffer.put(cast(byte) ((payloadLength >> 8) & 0xFF));
            buffer.put(cast(byte) (payloadLength & 0xFF));
        } else if (payloadLength >= 0x7E) {
            // Indicator 126: a 2-byte extended length follows (lengths 126..65535).
            b |= 0x7E;
            buffer.put(b);
            buffer.put(cast(byte) (payloadLength >> 8));
            buffer.put(cast(byte) (payloadLength & 0xFF));
        } else {
            // Lengths 0..125 fit directly in the 7-bit field.
            b |= (payloadLength & 0x7F);
            buffer.put(b);
        }

        if (frame.isMasked()) {
            byte[] mask = frame.getMask();
            // The MASK bit was set above, so the masking key MUST follow (RFC 6455 Section 5.2).
            // This is written even in read-only mode, since it does not modify the frame.
            buffer.put(mask);

            // Masking the payload mutates the frame's buffer, which read-only mode forbids.
            if (!readOnly) {
                maskPayload(frame.getPayload(), mask);
            }
        }

        BufferUtils.flipToFlush(buffer, p);
    }

    /**
     * XOR the remaining bytes of {@code payload} in place with the 4-byte {@code mask}.
     * <p>
     * The buffer's position and limit are not changed. Null or empty payloads are left alone.
     */
    private static void maskPayload(ByteBuffer payload, byte[] mask) {
        if ((payload is null) || (payload.remaining() <= 0)) {
            return;
        }

        // Pack the mask big-endian so whole 4-byte words can be XORed at once.
        int maskInt = 0;
        foreach (byte maskByte ; mask) {
            maskInt = (maskInt << 8) + (maskByte & 0xFF);
        }

        // NOTE(review): the word-at-a-time path assumes payload.getInt/putInt use big-endian
        // order (the Java NIO default). Confirm for hunt's ByteBuffer.
        int maskOffset = 0;
        int start = payload.position();
        int end = payload.limit();
        int remaining;
        while ((remaining = end - start) > 0) {
            if (remaining >= 4) {
                payload.putInt(start, payload.getInt(start) ^ maskInt);
                start += 4;
            } else {
                // Tail of fewer than 4 bytes. The preceding words consumed multiples of 4,
                // so the mask index restarts at 0.
                payload.put(start, cast(byte) (payload.get(start) ^ mask[maskOffset & 3]));
                ++start;
                ++maskOffset;
            }
        }
    }

    /**
     * Generate the whole frame (header + payload copy) into a single ByteBuffer.
     * <p>
     * Note: This is slow, moves lots of memory around. Only use this if you must (such as in unit testing).
     * <p>
     * In read-only mode a slice of the payload is copied. Otherwise the payload buffer itself is
     * passed to {@code buf.put}. NOTE(review): if put(ByteBuffer) consumes its source as in Java
     * NIO, this advances the frame payload's position.
     *
     * @param frame the frame to generate
     * @param buf   the buffer to output the generated frame to
     */
    void generateWholeFrame(Frame frame, ByteBuffer buf) {
        buf.put(generateHeaderBytes(frame));
        if (frame.hasPayload()) {
            if (readOnly) {
                buf.put(frame.getPayload().slice());
            } else {
                buf.put(frame.getPayload());
            }
        }
    }

    /**
     * Allow or forbid RSV1.
     *
     * @throws RuntimeException if this generator is read-only
     */
    void setRsv1InUse(bool rsv1InUse) {
        if (readOnly) {
            throw new RuntimeException("Not allowed to modify read-only frame");
        }
        flagsInUse = cast(byte) ((flagsInUse & 0xBF) | (rsv1InUse ? 0x40 : 0x00));
    }

    /**
     * Allow or forbid RSV2.
     *
     * @throws RuntimeException if this generator is read-only
     */
    void setRsv2InUse(bool rsv2InUse) {
        if (readOnly) {
            throw new RuntimeException("Not allowed to modify read-only frame");
        }
        flagsInUse = cast(byte) ((flagsInUse & 0xDF) | (rsv2InUse ? 0x20 : 0x00));
    }

    /**
     * Allow or forbid RSV3.
     *
     * @throws RuntimeException if this generator is read-only
     */
    void setRsv3InUse(bool rsv3InUse) {
        if (readOnly) {
            throw new RuntimeException("Not allowed to modify read-only frame");
        }
        flagsInUse = cast(byte) ((flagsInUse & 0xEF) | (rsv3InUse ? 0x10 : 0x00));
    }

    /** Returns true if RSV1 is allowed by the configured extensions. */
    bool isRsv1InUse() {
        return (flagsInUse & 0x40) != 0;
    }

    /** Returns true if RSV2 is allowed by the configured extensions. */
    bool isRsv2InUse() {
        return (flagsInUse & 0x20) != 0;
    }

    /** Returns true if RSV3 is allowed by the configured extensions. */
    bool isRsv3InUse() {
        return (flagsInUse & 0x10) != 0;
    }

    override
    string toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("Generator[");
        builder.append(behavior);
        if (validating) {
            builder.append(",validating");
        }
        if (isRsv1InUse()) {
            builder.append(",+rsv1");
        }
        if (isRsv2InUse()) {
            builder.append(",+rsv2");
        }
        if (isRsv3InUse()) {
            builder.append(",+rsv3");
        }
        builder.append("]");
        return builder.toString();
    }
}