Convert dart:io websocket to web_socket_channel

Now send token as query parameter instead
This commit is contained in:
Alexander Laevens
2022-11-30 16:02:42 -07:00
parent 2e7a306279
commit 7dd7abc09c
5 changed files with 49 additions and 43 deletions

View File

@@ -1,6 +1,4 @@
import 'dart:convert'; import 'dart:convert';
import 'dart:io';
import 'package:flutter/material.dart'; import 'package:flutter/material.dart';
import 'package:one_trip/api/auth.dart'; import 'package:one_trip/api/auth.dart';
import 'package:one_trip/api/consts.dart'; import 'package:one_trip/api/consts.dart';
@@ -10,6 +8,7 @@ import 'package:one_trip/api/models/user.dart';
import 'package:one_trip/pages/list_page/widgets/listrow.dart'; import 'package:one_trip/pages/list_page/widgets/listrow.dart';
import 'package:one_trip/pages/list_page/widgets/search_recipes.dart'; import 'package:one_trip/pages/list_page/widgets/search_recipes.dart';
import 'package:one_trip/widgets/text_entry_dialog.dart'; import 'package:one_trip/widgets/text_entry_dialog.dart';
import 'package:web_socket_channel/web_socket_channel.dart';
class ListPage extends StatefulWidget { class ListPage extends StatefulWidget {
const ListPage({super.key}); const ListPage({super.key});
@@ -22,7 +21,7 @@ class _ListPageState extends State<ListPage> {
ShoppingList? _list; ShoppingList? _list;
late Future<bool> _isLoaded; late Future<bool> _isLoaded;
User? _userInfo; User? _userInfo;
WebSocket? _ws; WebSocketChannel? _wsChannel;
Future<bool> _fetchList() async { Future<bool> _fetchList() async {
User? userInfo = await User.getMe(); User? userInfo = await User.getMe();
@@ -33,19 +32,16 @@ class _ListPageState extends State<ListPage> {
} }
_list = await ShoppingList.get(userInfo.homegroup!); _list = await ShoppingList.get(userInfo.homegroup!);
_connectSocket();
return true; return true;
} }
void _connectSocket() async { void _connectSocket() async {
String token = TokenSingleton().getToken(); String token = TokenSingleton().getToken();
_ws = await WebSocket.connect("$baseWsURL/ws/", _wsChannel = WebSocketChannel.connect(
headers: {"Authorization": "Token $token"}); Uri.parse("$baseWsURL/ws/?authorization=$token"));
_wsChannel!.stream.listen(
if (_ws == null) { (event) async {
return;
}
_ws!.listen((event) async {
Map<String, dynamic> json = jsonDecode(event); Map<String, dynamic> json = jsonDecode(event);
if (json.keys.contains("type") && json["type"] == "recommend_update") { if (json.keys.contains("type") && json["type"] == "recommend_update") {
@@ -59,20 +55,24 @@ class _ListPageState extends State<ListPage> {
} }
} }
} }
}); },
onError: (error) => print("Websocket error: $error"),
onDone: () => print("Websocket Done"),
);
} }
void _sendUpdate() async { void _sendUpdate() async {
if (_ws == null) { if (_wsChannel == null) {
return; return;
} }
_ws!.add(jsonEncode({"type": "broadcast_update", "hash": _list.hashCode})); _wsChannel!.sink
.add(jsonEncode({"type": "broadcast_update", "hash": _list.hashCode}));
} }
@override @override
void dispose() { void dispose() {
if (_ws != null) { if (_wsChannel != null) {
_ws!.close(); _wsChannel!.sink.close();
} }
super.dispose(); super.dispose();
} }
@@ -81,7 +81,6 @@ class _ListPageState extends State<ListPage> {
void initState() { void initState() {
super.initState(); super.initState();
_isLoaded = _fetchList(); _isLoaded = _fetchList();
_connectSocket();
} }
@override @override

View File

@@ -392,6 +392,13 @@ packages:
url: "https://pub.dartlang.org" url: "https://pub.dartlang.org"
source: hosted source: hosted
version: "2.1.2" version: "2.1.2"
web_socket_channel:
dependency: "direct main"
description:
name: web_socket_channel
url: "https://pub.dartlang.org"
source: hosted
version: "2.2.0"
xml: xml:
dependency: transitive dependency: transitive
description: description:

View File

@@ -38,6 +38,7 @@ dependencies:
flutter_svg_provider: ^1.0.3 flutter_svg_provider: ^1.0.3
image_picker: ^0.8.6 image_picker: ^0.8.6
flutter_launcher_icons: ^0.11.0 flutter_launcher_icons: ^0.11.0
web_socket_channel: ^2.2.0
# The following adds the Cupertino Icons font to your application. # The following adds the Cupertino Icons font to your application.

View File

@@ -10,8 +10,6 @@ https://docs.djangoproject.com/en/4.1/howto/deployment/asgi/
import os import os
import django import django
django.setup()
from django.core.asgi import get_asgi_application from django.core.asgi import get_asgi_application
from channels.routing import ProtocolTypeRouter, URLRouter from channels.routing import ProtocolTypeRouter, URLRouter
from channels.auth import AuthMiddlewareStack from channels.auth import AuthMiddlewareStack
@@ -22,6 +20,7 @@ if os.getenv("DJANGO_RELEASE", False):
settings = 'one_trip_api.settings.release' settings = 'one_trip_api.settings.release'
os.environ.setdefault('DJANGO_SETTINGS_MODULE', settings) os.environ.setdefault('DJANGO_SETTINGS_MODULE', settings)
django.setup()
print("ASGI Started") print("ASGI Started")
django_asgi_app = get_asgi_application() django_asgi_app = get_asgi_application()

View File

@@ -3,12 +3,17 @@ from channels.generic.websocket import AsyncJsonWebsocketConsumer
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
from api.models import Homegroup from api.models import Homegroup
from users.models import User from users.models import User
from urllib.parse import parse_qs
class ChatConsumer(AsyncJsonWebsocketConsumer): class ChatConsumer(AsyncJsonWebsocketConsumer):
async def connect(self): async def connect(self):
token_homegroup = await self.get_homegroup_by_token(self.scope["headers"]) query_params = parse_qs(self.scope["query_string"].decode())
query_params.setdefault("authorization", [""])
token_homegroup = await self.get_homegroup_by_token(query_params["authorization"][0])
if token_homegroup is None: if token_homegroup is None:
await self.disconnect(1) await self.accept()
await self.close(3000)
else: else:
self.room_name = token_homegroup.id self.room_name = token_homegroup.id
self.room_group_name = f"group_{self.room_name}" self.room_group_name = f"group_{self.room_name}"
@@ -23,19 +28,14 @@ class ChatConsumer(AsyncJsonWebsocketConsumer):
) )
async def disconnect(self, close_code): async def disconnect(self, close_code):
if (close_code != 3000):
await self.channel_layer.group_discard(self.room_group_name, self.channel_name) await self.channel_layer.group_discard(self.room_group_name, self.channel_name)
async def broadcast_update(self, event): async def broadcast_update(self, event):
print(event)
await self.send_json(content={"type": "recommend_update", "hash": event["hash"]}) await self.send_json(content={"type": "recommend_update", "hash": event["hash"]})
@database_sync_to_async @database_sync_to_async
def get_homegroup_by_token(self, headers): def get_homegroup_by_token(self, tokenString):
headers = self.scope["headers"]
for pair in headers:
if pair[0].decode("UTF-8") == "authorization":
tokenType, tokenString = pair[1].decode("UTF-8").split()
queryset = Token.objects.filter(key=tokenString) queryset = Token.objects.filter(key=tokenString)
if queryset.exists(): if queryset.exists():
return Token.objects.get(key=tokenString).user.homegroup return Token.objects.get(key=tokenString).user.homegroup