From 7dd7abc09c90804034d663d60c1ccfa8f0b89a0b Mon Sep 17 00:00:00 2001 From: Alexander Laevens Date: Wed, 30 Nov 2022 16:02:42 -0700 Subject: [PATCH] Convert dart:io websocket to web_socket_channel Now send token as query parameter instead --- one_trip/lib/pages/list_page/list_page.dart | 51 ++++++++++----------- one_trip/pubspec.lock | 7 +++ one_trip/pubspec.yaml | 1 + one_trip_api/one_trip_api/asgi.py | 3 +- one_trip_api/ws/consumers.py | 30 ++++++------ 5 files changed, 49 insertions(+), 43 deletions(-) diff --git a/one_trip/lib/pages/list_page/list_page.dart b/one_trip/lib/pages/list_page/list_page.dart index c944800..3377412 100644 --- a/one_trip/lib/pages/list_page/list_page.dart +++ b/one_trip/lib/pages/list_page/list_page.dart @@ -1,6 +1,4 @@ import 'dart:convert'; -import 'dart:io'; - import 'package:flutter/material.dart'; import 'package:one_trip/api/auth.dart'; import 'package:one_trip/api/consts.dart'; @@ -10,6 +8,7 @@ import 'package:one_trip/api/models/user.dart'; import 'package:one_trip/pages/list_page/widgets/listrow.dart'; import 'package:one_trip/pages/list_page/widgets/search_recipes.dart'; import 'package:one_trip/widgets/text_entry_dialog.dart'; +import 'package:web_socket_channel/web_socket_channel.dart'; class ListPage extends StatefulWidget { const ListPage({super.key}); @@ -22,7 +21,7 @@ class _ListPageState extends State { ShoppingList? _list; late Future _isLoaded; User? _userInfo; - WebSocket? _ws; + WebSocketChannel? _wsChannel; Future _fetchList() async { User? userInfo = await User.getMe(); @@ -33,46 +32,47 @@ class _ListPageState extends State { } _list = await ShoppingList.get(userInfo.homegroup!); + _connectSocket(); return true; } void _connectSocket() async { String token = TokenSingleton().getToken(); - _ws = await WebSocket.connect("$baseWsURL/ws/", - headers: {"Authorization": "Token $token"}); + _wsChannel = WebSocketChannel.connect( + Uri.parse("$baseWsURL/ws/?authorization=$token")); + _wsChannel!.stream.listen( + (event) async { + Map json = jsonDecode(event); - if (_ws == null) { - return; - } + if (json.keys.contains("type") && json["type"] == "recommend_update") { + if (json["hash"] != _list.hashCode) { + ShoppingList? newList = await ShoppingList.get(_list!.homegroup); - _ws!.listen((event) async { - Map json = jsonDecode(event); - - if (json.keys.contains("type") && json["type"] == "recommend_update") { - if (json["hash"] != _list.hashCode) { - ShoppingList? newList = await ShoppingList.get(_list!.homegroup); - - if (newList != null) { - setState(() { - _list = newList; - }); + if (newList != null) { + setState(() { + _list = newList; + }); + } } } - } - }); + }, + onError: (error) => print("Websocket error: $error"), + onDone: () => print("Websocket Done"), + ); } void _sendUpdate() async { - if (_ws == null) { + if (_wsChannel == null) { return; } - _ws!.add(jsonEncode({"type": "broadcast_update", "hash": _list.hashCode})); + _wsChannel!.sink + .add(jsonEncode({"type": "broadcast_update", "hash": _list.hashCode})); } @override void dispose() { - if (_ws != null) { - _ws!.close(); + if (_wsChannel != null) { + _wsChannel!.sink.close(); } super.dispose(); } @@ -81,7 +81,6 @@ class _ListPageState extends State { void initState() { super.initState(); _isLoaded = _fetchList(); - _connectSocket(); } @override diff --git a/one_trip/pubspec.lock b/one_trip/pubspec.lock index 3ab3426..a37bd18 100644 --- a/one_trip/pubspec.lock +++ b/one_trip/pubspec.lock @@ -392,6 +392,13 @@ packages: url: "https://pub.dartlang.org" source: hosted version: "2.1.2" + web_socket_channel: + dependency: "direct main" + description: + name: web_socket_channel + url: "https://pub.dartlang.org" + source: hosted + version: "2.2.0" xml: dependency: transitive description: diff --git a/one_trip/pubspec.yaml b/one_trip/pubspec.yaml index 5215636..9e5e99f 100644 --- a/one_trip/pubspec.yaml +++ b/one_trip/pubspec.yaml @@ -38,6 +38,7 @@ dependencies: flutter_svg_provider: ^1.0.3 image_picker: ^0.8.6 flutter_launcher_icons: ^0.11.0 + web_socket_channel: ^2.2.0 # The following adds the Cupertino Icons font to your application. diff --git a/one_trip_api/one_trip_api/asgi.py b/one_trip_api/one_trip_api/asgi.py index bfe38b5..d65d285 100644 --- a/one_trip_api/one_trip_api/asgi.py +++ b/one_trip_api/one_trip_api/asgi.py @@ -10,8 +10,6 @@ https://docs.djangoproject.com/en/4.1/howto/deployment/asgi/ import os import django -django.setup() - from django.core.asgi import get_asgi_application from channels.routing import ProtocolTypeRouter, URLRouter from channels.auth import AuthMiddlewareStack @@ -22,6 +20,7 @@ if os.getenv("DJANGO_RELEASE", False): settings = 'one_trip_api.settings.release' os.environ.setdefault('DJANGO_SETTINGS_MODULE', settings) +django.setup() print("ASGI Started") django_asgi_app = get_asgi_application() diff --git a/one_trip_api/ws/consumers.py b/one_trip_api/ws/consumers.py index fdf3046..487a9c3 100644 --- a/one_trip_api/ws/consumers.py +++ b/one_trip_api/ws/consumers.py @@ -3,12 +3,17 @@ from channels.generic.websocket import AsyncJsonWebsocketConsumer from rest_framework.authtoken.models import Token from api.models import Homegroup from users.models import User +from urllib.parse import parse_qs class ChatConsumer(AsyncJsonWebsocketConsumer): async def connect(self): - token_homegroup = await self.get_homegroup_by_token(self.scope["headers"]) + query_params = parse_qs(self.scope["query_string"].decode()) + query_params.setdefault("authorization", [""]) + + token_homegroup = await self.get_homegroup_by_token(query_params["authorization"][0]) if token_homegroup is None: - await self.disconnect(1) + await self.accept() + await self.close(3000) else: self.room_name = token_homegroup.id self.room_group_name = f"group_{self.room_name}" @@ -23,24 +28,19 @@ class ChatConsumer(AsyncJsonWebsocketConsumer): ) async def disconnect(self, close_code): - await self.channel_layer.group_discard(self.room_group_name, self.channel_name) + if (close_code != 3000): + await self.channel_layer.group_discard(self.room_group_name, self.channel_name) async def broadcast_update(self, event): - print(event) await self.send_json(content={"type": "recommend_update", "hash": event["hash"]}) @database_sync_to_async - def get_homegroup_by_token(self, headers): - headers = self.scope["headers"] - for pair in headers: - if pair[0].decode("UTF-8") == "authorization": - tokenType, tokenString = pair[1].decode("UTF-8").split() - - queryset = Token.objects.filter(key=tokenString) - if queryset.exists(): - return Token.objects.get(key=tokenString).user.homegroup - else: - return None + def get_homegroup_by_token(self, tokenString): + queryset = Token.objects.filter(key=tokenString) + if queryset.exists(): + return Token.objects.get(key=tokenString).user.homegroup + else: + return None @database_sync_to_async def get_homegroup_by_id(self, group_id):