add websock functionality
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
from django.db.models.signals import post_save
|
||||
|
||||
class ApiConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
|
||||
18
one_trip_api/api/migrations/0005_list_updates.py
Normal file
18
one_trip_api/api/migrations/0005_list_updates.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 4.1.3 on 2022-11-29 16:15
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('api', '0004_remove_recipe_list_delete_ingredient'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='list',
|
||||
name='updates',
|
||||
field=models.BigIntegerField(default=0),
|
||||
),
|
||||
]
|
||||
@@ -26,6 +26,7 @@ class Homegroup(models.Model):
|
||||
class List(models.Model):
|
||||
# Foreign Key ListIngredient -> List [as ingredients]
|
||||
homegroup = models.OneToOneField(Homegroup, on_delete=models.CASCADE, primary_key=True)
|
||||
updates = models.BigIntegerField(default=0);
|
||||
|
||||
|
||||
class Recipe(models.Model):
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from rest_framework import serializers
|
||||
from api.models import *
|
||||
from users.serializers import UserSerializer
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
|
||||
channel_layer = get_channel_layer()
|
||||
|
||||
class RecipeIngredientSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
@@ -29,9 +33,13 @@ class ListSerializer(serializers.ModelSerializer):
|
||||
|
||||
class Meta:
|
||||
model = List
|
||||
fields = ["homegroup", "ingredients"]
|
||||
fields = ["homegroup", "updates", "ingredients"]
|
||||
read_only_fields = ["homegroup"]
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
# async_to_sync(channel_layer.group_send)(f"group_{instance.homegroup.id}", {"type": "model_update"})
|
||||
return super().update(instance, validated_data)
|
||||
|
||||
def get_ingredients(self, instance):
|
||||
ingredients = instance.ingredients.all().order_by("name")
|
||||
return ListIngredientSerializer(ingredients, many=True).data
|
||||
|
||||
@@ -3,7 +3,8 @@ from rest_framework import routers
|
||||
from api import views
|
||||
|
||||
router = routers.DefaultRouter()
|
||||
router.register(r'recipes', views.RecipeView)
|
||||
router.register(r'recipes', views.RecipeAllView)
|
||||
router.register(r'searchrecipes', views.RecipeSearchView)
|
||||
router.register(r'lists', views.ListView)
|
||||
router.register(r'recipeingredients', views.RecipeIngredientView)
|
||||
router.register(r'listingredients', views.ListIngredientView)
|
||||
|
||||
@@ -1,28 +1,69 @@
|
||||
from rest_framework import viewsets, mixins, views, status, permissions
|
||||
from rest_framework import viewsets, mixins, permissions, request, pagination, filters
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.request import Request
|
||||
from api.serializers import *
|
||||
from api.models import *
|
||||
|
||||
class RecipeView(viewsets.ModelViewSet):
|
||||
class HasHomegroup(permissions.BasePermission):
|
||||
def has_permission(self, request: Request, view):
|
||||
if not request.user.homegroup:
|
||||
return False
|
||||
|
||||
return super().has_permission(request, view)
|
||||
|
||||
class Pagination(pagination.PageNumberPagination):
|
||||
page_size = 4
|
||||
|
||||
class NoListModelViewset(mixins.CreateModelMixin, mixins.DestroyModelMixin, mixins.UpdateModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet):
|
||||
pass
|
||||
|
||||
class RecipeSearchView(viewsets.ModelViewSet):
|
||||
serializer_class = RecipeSerializer
|
||||
permission_classes = [permissions.IsAuthenticated, HasHomegroup]
|
||||
queryset = Recipe.objects.all()
|
||||
filter_backends = [filters.SearchFilter]
|
||||
search_fields = ["name"]
|
||||
pagination_class = Pagination
|
||||
|
||||
def list(self, request: Request, *args, **kwargs):
|
||||
queryset = self.filter_queryset(Recipe.objects.filter(homegroup=request.user.homegroup).order_by("name"));
|
||||
|
||||
page = self.paginate_queryset(queryset)
|
||||
if page is not None:
|
||||
serializer = self.get_serializer(page, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
||||
|
||||
serializer = self.serializer_class(queryset, many=True)
|
||||
return Response(serializer.data)
|
||||
|
||||
class RecipeAllView(viewsets.ModelViewSet):
|
||||
serializer_class = RecipeSerializer
|
||||
permission_classes = [permissions.IsAuthenticated, HasHomegroup]
|
||||
queryset = Recipe.objects.all()
|
||||
filter_backends = [filters.SearchFilter]
|
||||
search_fields = ["name"]
|
||||
|
||||
def list(self, request: Request, *args, **kwargs):
|
||||
queryset = self.filter_queryset(Recipe.objects.filter(homegroup=request.user.homegroup).order_by("name"));
|
||||
serializer = self.serializer_class(queryset, many=True)
|
||||
return Response(serializer.data)
|
||||
|
||||
class HomegroupView(viewsets.ModelViewSet):
|
||||
serializer_class = HomegroupSerializer
|
||||
queryset = Homegroup.objects.all()
|
||||
|
||||
class HomegroupInviteView(viewsets.ModelViewSet):
|
||||
class HomegroupInviteView(NoListModelViewset):
|
||||
serializer_class = InviteSerializer
|
||||
queryset = HomegroupInvite.objects.all()
|
||||
|
||||
class RecipeIngredientView(viewsets.ModelViewSet):
|
||||
class RecipeIngredientView(NoListModelViewset):
|
||||
serializer_class = RecipeIngredientSerializer
|
||||
queryset = RecipeIngredient.objects.all()
|
||||
|
||||
class ListIngredientView(viewsets.ModelViewSet):
|
||||
class ListIngredientView(NoListModelViewset):
|
||||
serializer_class = ListIngredientSerializer
|
||||
queryset = ListIngredient.objects.all()
|
||||
|
||||
class ListView(mixins.RetrieveModelMixin, viewsets.GenericViewSet):
|
||||
class ListView(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, viewsets.GenericViewSet):
|
||||
serializer_class = ListSerializer
|
||||
queryset = List.objects.all()
|
||||
@@ -10,6 +10,9 @@ https://docs.djangoproject.com/en/4.1/howto/deployment/asgi/
|
||||
import os
|
||||
|
||||
from django.core.asgi import get_asgi_application
|
||||
from channels.routing import ProtocolTypeRouter, URLRouter
|
||||
from channels.auth import AuthMiddlewareStack
|
||||
import ws.routing
|
||||
|
||||
settings = 'one_trip_api.settings.dev'
|
||||
if os.getenv("DJANGO_RELEASE", False):
|
||||
@@ -17,5 +20,10 @@ if os.getenv("DJANGO_RELEASE", False):
|
||||
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', settings)
|
||||
|
||||
print("ASGI Started")
|
||||
django_asgi_app = get_asgi_application()
|
||||
|
||||
application = get_asgi_application()
|
||||
application = ProtocolTypeRouter({
|
||||
"http": django_asgi_app,
|
||||
"websocket": AuthMiddlewareStack(URLRouter(ws.routing.websocket_urlpatterns))
|
||||
})
|
||||
|
||||
@@ -33,6 +33,8 @@ REST_FRAMEWORK = {
|
||||
INSTALLED_APPS = [
|
||||
'api',
|
||||
'users',
|
||||
'ws',
|
||||
'daphne',
|
||||
'django.contrib.admin',
|
||||
'django.contrib.auth',
|
||||
'django.contrib.contenttypes',
|
||||
@@ -76,6 +78,15 @@ TEMPLATES = [
|
||||
]
|
||||
|
||||
WSGI_APPLICATION = 'one_trip_api.wsgi.application'
|
||||
ASGI_APPLICATION = 'one_trip_api.asgi.application'
|
||||
CHANNEL_LAYERS = {
|
||||
"default": {
|
||||
"BACKEND": "channels_redis.core.RedisChannelLayer",
|
||||
"CONFIG": {
|
||||
"hosts": [("127.0.0.1", 6379)],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
DATABASES = {
|
||||
'default': {
|
||||
|
||||
@@ -17,5 +17,5 @@ if os.getenv("DJANGO_RELEASE", False):
|
||||
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', settings)
|
||||
|
||||
|
||||
print("WSGI Started")
|
||||
application = get_wsgi_application()
|
||||
|
||||
@@ -8,7 +8,7 @@ class ExemptCSRFMiddleware:
|
||||
|
||||
def __call__(self, request):
|
||||
|
||||
if request.path_info == "/auth/token":
|
||||
if request.path_info in ["/auth/token", "/auth/users/"]:
|
||||
setattr(request, '_dont_enforce_csrf_checks', True)
|
||||
|
||||
response = self.get_response(request)
|
||||
|
||||
0
one_trip_api/ws/__init__.py
Normal file
0
one_trip_api/ws/__init__.py
Normal file
3
one_trip_api/ws/admin.py
Normal file
3
one_trip_api/ws/admin.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from django.contrib import admin
|
||||
|
||||
# Register your models here.
|
||||
6
one_trip_api/ws/apps.py
Normal file
6
one_trip_api/ws/apps.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class WsConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
name = 'ws'
|
||||
51
one_trip_api/ws/consumers.py
Normal file
51
one_trip_api/ws/consumers.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from channels.db import database_sync_to_async
|
||||
from channels.generic.websocket import AsyncJsonWebsocketConsumer
|
||||
from rest_framework.authtoken.models import Token
|
||||
from api.models import Homegroup
|
||||
from users.models import User
|
||||
|
||||
class ChatConsumer(AsyncJsonWebsocketConsumer):
|
||||
async def connect(self):
|
||||
token_homegroup = await self.get_homegroup_by_token(self.scope["headers"])
|
||||
if token_homegroup is None:
|
||||
await self.disconnect(1)
|
||||
else:
|
||||
self.room_name = token_homegroup.id
|
||||
self.room_group_name = f"group_{self.room_name}"
|
||||
await self.channel_layer.group_add(self.room_group_name, self.channel_name)
|
||||
await self.accept()
|
||||
|
||||
|
||||
async def receive_json(self, content, **kwargs):
|
||||
await self.channel_layer.group_send(
|
||||
self.room_group_name,
|
||||
content
|
||||
)
|
||||
|
||||
async def disconnect(self, close_code):
|
||||
await self.channel_layer.group_discard(self.room_group_name, self.channel_name)
|
||||
|
||||
async def broadcast_update(self, event):
|
||||
print(event)
|
||||
await self.send_json(content={"type": "recommend_update", "hash": event["hash"]})
|
||||
|
||||
@database_sync_to_async
|
||||
def get_homegroup_by_token(self, headers):
|
||||
headers = self.scope["headers"]
|
||||
for pair in headers:
|
||||
if pair[0].decode("UTF-8") == "authorization":
|
||||
tokenType, tokenString = pair[1].decode("UTF-8").split()
|
||||
|
||||
queryset = Token.objects.filter(key=tokenString)
|
||||
if queryset.exists():
|
||||
return Token.objects.get(key=tokenString).user.homegroup
|
||||
else:
|
||||
return None
|
||||
|
||||
@database_sync_to_async
|
||||
def get_homegroup_by_id(self, group_id):
|
||||
queryset = Homegroup.objects.filter(id=group_id)
|
||||
if queryset.exists():
|
||||
return queryset.get()
|
||||
else:
|
||||
return None
|
||||
0
one_trip_api/ws/migrations/__init__.py
Normal file
0
one_trip_api/ws/migrations/__init__.py
Normal file
3
one_trip_api/ws/models.py
Normal file
3
one_trip_api/ws/models.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from django.db import models
|
||||
|
||||
# Create your models here.
|
||||
7
one_trip_api/ws/routing.py
Normal file
7
one_trip_api/ws/routing.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from django.urls import re_path, path
|
||||
|
||||
from ws import consumers
|
||||
|
||||
websocket_urlpatterns = [
|
||||
path('ws/', consumers.ChatConsumer.as_asgi(), name='room')
|
||||
]
|
||||
3
one_trip_api/ws/tests.py
Normal file
3
one_trip_api/ws/tests.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from django.test import TestCase
|
||||
|
||||
# Create your tests here.
|
||||
4
one_trip_api/ws/views.py
Normal file
4
one_trip_api/ws/views.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from django.shortcuts import render
|
||||
from rest_framework.views import APIView
|
||||
|
||||
# Create your views here.
|
||||
Reference in New Issue
Block a user