EStiMo/connection/RDA.py
2023-06-14 14:25:52 +02:00

304 lines
10 KiB
Python

# -*- coding: utf-8 -*-
"""
Created on Wed Nov 16 15:00:23 2022
@author: s202442
"""
# needs socket and struct library
from socket import socket, AF_INET, SOCK_STREAM
from struct import unpack
import sys
import numpy as np
import threading
import queue
import time
"""Packets arrive every 20 ms (50 packets per second), so the number of samples
per packet is the sampling rate divided by 50, e.g.:
for 1000 Hz the packet size is 20, because 20*50=1000
for 2500 Hz the packet size is 50, because 50*50=2500
for 50 kHz the packet size is 1000, because 1000*50=50000
"""
def average(arr, n, mode='mean'):
    """Downsample `arr` by combining every `n` consecutive samples.

    Parameters
    ----------
    arr : numpy.ndarray
        For ``mode='max'`` a 1-D array of samples. For any other mode a 2-D
        array of shape (samples, channels).
    n : int
        Number of consecutive samples merged into one output sample; must be
        a positive integer.
    mode : str
        ``'max'`` takes the maximum of each group of `n` samples of a 1-D
        array; anything else takes the per-channel mean of each group.

    Returns
    -------
    numpy.ndarray
        ``len(arr) // n`` reduced samples: 1-D for ``'max'``, a float64 array
        of shape (samples // n, channels) otherwise.

    Trailing samples that do not fill a complete group of `n` are discarded
    in both modes.
    """
    if mode == 'max':
        end = n * (len(arr) // n)
        return np.max(arr[:end].reshape(-1, n), 1)
    groups = arr.shape[0] // n
    # Group rows as (group, sample-in-group, channel) and average over the
    # middle axis; this handles all channels in one vectorized call.
    grouped = arr[:groups * n].reshape(groups, n, arr.shape[1])
    return grouped.mean(axis=1, dtype=np.float64)
# Marker class for storing marker information
class Marker:
    """Plain record holding the fields of one marker from a data message.

    All fields start with placeholder values and are filled in by the code
    that parses a data message.
    """

    def __init__(self):
        self.position = 0       # marker position as read from the message
        self.points = 0         # number of points as read from the message
        self.channel = -1       # channel number as read from the message
        self.type = ""          # first string of the marker's text payload
        self.description = ""   # second string of the marker's text payload

    def __repr__(self):
        # Readable representation so markers can be printed while debugging.
        return ("Marker(position=%r, points=%r, channel=%r, type=%r, "
                "description=%r)" % (self.position, self.points, self.channel,
                                     self.type, self.description))
# Helper function for receiving whole message
def RecvData(socket, requestedSize):
    """Receive exactly `requestedSize` bytes from a connected socket.

    A single ``recv`` call may return fewer bytes than asked for, so the
    call is repeated until the whole message has arrived.

    Parameters
    ----------
    socket : object with a ``recv(size)`` method returning bytes
    requestedSize : int
        Number of bytes to read; values <= 0 return ``b''`` without reading.

    Returns
    -------
    bytes
        Exactly `requestedSize` bytes.

    Raises
    ------
    RuntimeError
        If the peer closes the connection before all bytes were received.
    """
    chunks = []
    remaining = requestedSize
    while remaining > 0:
        databytes = socket.recv(remaining)
        # recv() returns an empty bytes object once the peer has closed the
        # connection; comparing against the str '' would never match and the
        # loop would spin forever.
        if not databytes:
            raise RuntimeError(
                "connection closed after %d of %d requested bytes"
                % (requestedSize - remaining, requestedSize))
        chunks.append(databytes)
        remaining -= len(databytes)
    return b''.join(chunks)
# Helper function for splitting a raw array of
# zero terminated strings (C) into an array of python strings
def SplitString(raw):
    """Split a buffer of zero-terminated (C style) strings into Python strings.

    Parameters
    ----------
    raw : bytes-like
        Concatenated strings, each terminated by a zero byte.

    Returns
    -------
    list of str
        One decoded string per terminator found, in order. Bytes following
        the last terminator (an unterminated fragment) are ignored.
    """
    # split() always yields one more piece than there are terminators; that
    # last piece is the unterminated remainder, which is dropped.
    pieces = bytes(raw).split(b'\x00')[:-1]
    return [piece.decode() for piece in pieces]
# Helper function for extracting eeg properties from a raw data array
# read from tcpip socket
def GetProperties(rawdata):
    """Parse the payload of a start message into the recording properties.

    Layout read from `rawdata` (little-endian): an unsigned 32-bit channel
    count, a 64-bit float sampling interval, one 64-bit float resolution per
    channel, then the zero-terminated channel names.

    Parameters
    ----------
    rawdata : bytes
        Message payload without the message header.

    Returns
    -------
    tuple
        ``(channelCount, samplingInterval, resolutions, channelNames)`` where
        `resolutions` is a list of floats and `channelNames` a list of str.

    Raises
    ------
    struct.error
        If `rawdata` is too short for the announced number of channels.
    """
    # Extract numerical data
    (channelCount, samplingInterval) = unpack('<Ld', rawdata[:12])
    # Extract all resolutions with a single unpack call
    namesStart = 12 + 8 * channelCount
    resolutions = list(unpack('<%dd' % channelCount, rawdata[12:namesStart]))
    # Extract channel names
    channelNames = SplitString(rawdata[namesStart:])
    return (channelCount, samplingInterval, resolutions, channelNames)
# Helper function for extracting eeg and marker data from a raw data array
# read from tcpip socket
def GetData(rawdata, channelCount):
    """Parse the payload of a data message into samples and markers.

    Layout read from `rawdata` (little-endian): three unsigned 32-bit values
    (block number, points per channel, marker count), then
    ``points * channelCount`` 32-bit floats, then the markers. Each marker
    starts with its own size (unsigned 32-bit), followed by position, points
    and channel, and finally the zero-terminated type and description.

    Parameters
    ----------
    rawdata : bytes
        Message payload without the message header.
    channelCount : int
        Number of channels, as announced by the start message.

    Returns
    -------
    tuple
        ``(block, points, markerCount, data, markers)`` where `data` is a flat
        list of floats in the order they appear in the message and `markers`
        is a list of Marker objects.

    Raises
    ------
    struct.error
        If `rawdata` is shorter than its own header fields announce.
    """
    # Extract numerical data
    (block, points, markerCount) = unpack('<LLL', rawdata[:12])
    # Extract eeg data as list of floats in one unpack call instead of one
    # call per sample
    sampleCount = points * channelCount
    markerStart = 12 + 4 * sampleCount
    data = list(unpack('<%df' % sampleCount, rawdata[12:markerStart]))
    # Extract markers
    markers = []
    index = markerStart
    for _ in range(markerCount):
        (markersize,) = unpack('<L', rawdata[index:index + 4])
        ma = Marker()
        (ma.position, ma.points, ma.channel) = unpack('<LLl', rawdata[index + 4:index + 16])
        typedesc = SplitString(rawdata[index + 16:index + markersize])
        # A marker whose text part holds fewer than two strings keeps the
        # Marker defaults instead of raising IndexError.
        if len(typedesc) > 0:
            ma.type = typedesc[0]
        if len(typedesc) > 1:
            ma.description = typedesc[1]
        markers.append(ma)
        index = index + markersize
    return (block, points, markerCount, data, markers)
def sampleLoop(obj):
    """Receive messages from `obj.sock` and feed the data into `obj`.

    Runs until `obj.stop` becomes true or a stop message (type 3) arrives.
    Each message consists of a 24-byte header followed by a payload:

    * type 1 (start): the recording properties are parsed and printed.
    * type 4 (data): the samples are scaled by the channel resolutions, a
      marker column (1 at marker positions, 0 elsewhere) is appended as the
      last column, and the result is passed to `obj.updateRingBuffer`
      together with the block number. When `obj.avgPackets` is set the
      packet is first averaged down with `average`.
    * type 3 (stop): the loop ends.

    The socket is closed when the loop ends, including when an exception
    (for example a socket timeout) escapes.

    Parameters
    ----------
    obj : object
        Must provide `sock`, `stop`, `avgPackets` and
        `updateRingBuffer(data, block)`.
    """
    channelCount = None  # unknown until a start message has been received
    resolutions = None
    lastBlock = -1       # -1 means: no data block seen since the last start
    try:
        while not obj.stop:
            # Get message header as raw array of chars
            rawhdr = RecvData(obj.sock, 24)
            # Split array into useful information; id1 to id4 are constants
            (id1, id2, id3, id4, msgsize, msgtype) = unpack('<llllLL', rawhdr)
            # Get data part of message, which is of variable size
            rawdata = RecvData(obj.sock, msgsize - 24)
            if msgtype == 1:
                # Start message, extract eeg properties and display them
                (channelCount, samplingInterval, resolutions, channelNames) = GetProperties(rawdata)
                # reset block counter
                lastBlock = -1
                print("Start")
                print("Number of channels: " + str(channelCount))
                print("Sampling interval: " + str(samplingInterval))
                print("Resolutions: " + str(resolutions))
                print("Channel Names: " + str(channelNames))
            elif msgtype == 4:
                if channelCount is None:
                    # Data cannot be decoded before the start message has
                    # told us the channel count and resolutions.
                    continue
                # Data message, extract data and markers
                (block, points, markerCount, data, markers) = GetData(rawdata, channelCount)
                # Check for lost blocks; skipped for the first block after a
                # start message because there is nothing to compare against.
                if lastBlock != -1 and block > lastBlock + 1:
                    print('Dropped %i blocks' % (block - lastBlock - 1,))
                    print("*** Overflow with " + str(block - lastBlock) + " datablocks ***")
                lastBlock = block
                samples = np.array(data)
                sampleCount = int(len(samples) / channelCount)
                # Marker channel: 1 at every marker position of this block
                marker_sig = np.zeros([1, sampleCount])
                for ma in markers:
                    print("Marker " + ma.description + " of type " + ma.type)
                    marker_sig[0][ma.position] = 1
                # Scale to physical units and append the marker column
                data_array = samples.reshape([sampleCount, channelCount]) * np.array(resolutions)
                data_array = np.vstack([data_array.T, marker_sig]).T
                if obj.avgPackets:
                    # Packets shorter than 20 samples would give a factor of
                    # 0 and a division by zero; they are passed on unreduced.
                    resampling_coef = max(1, int((len(data) / channelCount) / 20))
                    data1 = average(data_array, resampling_coef, 'mean')
                    # Averaging would smear the markers, so the marker column
                    # is reduced with max instead.
                    data1[:, -1] = average(data_array[:, -1], resampling_coef, 'max')
                    obj.updateRingBuffer(data1, block)
                else:
                    obj.updateRingBuffer(data_array, block)
            elif msgtype == 3:
                # Stop message, terminate the loop
                print("Stop")
                break
    finally:
        obj.sock.close()
##############################################################################################
#
# Main RDA routine
#
##############################################################################################
class RDA():
    """TCP client that collects streamed recorder data in a ring buffer.

    `start` launches a background thread running `sampleLoop`, which calls
    `updateRingBuffer` for every received data packet. `getBuffer` returns a
    copy of the current data window and `stopit` ends the acquisition.
    """

    def __init__(self, ip='127.0.0.1', port=51244, buffersize=2**10, sendqueue=False,
                 si=1/1000, ringbuffersize=2**12, avgPackets=True):
        """Connect to the recorder host.

        Parameters
        ----------
        ip : str
            Host running the recorder; adapt if it is not the local machine.
        port : int
            51244 is the 32Bit RDA-port; 51234 is the 16Bit RDA-port.
        buffersize : int
            Stored as `bufsiz`.
        sendqueue : bool
            Stored as `sendqueue`; not used inside this class.
        si : float
            Stored as `si`; not used inside this class.
        ringbuffersize : int
            Number of rows (samples) kept in the ring buffer.
        avgPackets : bool
            Stored as `avgPackets` and read by `sampleLoop` to decide
            whether packets are averaged down before buffering.

        Raises
        ------
        OSError
            If the connection cannot be established.
        """
        self.avgPackets = avgPackets
        self.bufsiz = buffersize
        self.ip = ip
        self.port = port
        self.sampidx = 0
        self.tstamp = None
        self.tstamp0 = None
        self.queue = queue.Queue()
        self.A = None
        self.stop = False
        self.idx = 0
        self.ringbufferinit = True
        self.ringbuffersize = ringbuffersize
        self.sendqueue = sendqueue
        self.lock = threading.RLock()
        self.si = si
        # Create a tcpip socket (SOCK_STREAM is TCP, not UDP)
        self.sock = socket(AF_INET, SOCK_STREAM)
        try:
            self.sock.connect((ip, port))
            # Reads that stall for more than 2 s raise a timeout.
            self.sock.settimeout(2.)
        except OSError:
            # Do not leak the socket when the connection attempt fails.
            self.sock.close()
            raise

    def updateRingBuffer(self, data, i=None, tstamp=None):
        """Append a chunk of samples to the ring buffer.

        Parameters
        ----------
        data : numpy.ndarray
            2-D array of shape (samples, columns). The column count of the
            first chunk fixes the width of the buffer.
        i : optional
            If given, stored as `sampidx`.
        tstamp : optional
            If given, stored as `tstamp`.

        While there is free space the chunk is written behind the existing
        data. Once the buffer is full the oldest rows are shifted out so that
        the newest rows always sit at the end. A chunk longer than the buffer
        is cut down to its most recent rows.
        """
        # "with" guarantees the lock is released even if an assignment below
        # raises (e.g. on a column-count mismatch).
        with self.lock:
            if self.ringbufferinit:
                self.ringbuffer = np.zeros((self.ringbuffersize, data.shape[1]), dtype=np.float32)
                self.ringbufferinit = False
            ringbuf = self.ringbuffer
            wlen = ringbuf.shape[0]
            if data.shape[0] > wlen:
                # Only the newest wlen rows can fit into the buffer.
                data = data[-wlen:]
            rows = data.shape[0]
            if self.idx + rows <= wlen:
                ringbuf[self.idx:self.idx + rows, :] = data
                self.idx += rows
            else:
                # Move the rows that survive to the front, then write the new
                # chunk at the end.
                ringbuf[0:wlen - rows, :] = ringbuf[self.idx - wlen + rows:self.idx, :]
                self.idx = wlen
                ringbuf[wlen - rows:wlen, :] = data
            # NOTE(review): while idx < wlen the start index is negative, so
            # this slice is empty until the buffer has filled completely.
            # Kept as in the original; confirm this is what consumers expect.
            self.datawindow = ringbuf[self.idx - wlen:self.idx]
            if i is not None:
                self.sampidx = i
            if tstamp is not None:
                self.tstamp = tstamp

    def getBuffer(self, returnIdx=False):
        """Return a copy of the current data window.

        Parameters
        ----------
        returnIdx : bool
            If true, return ``(window, sampidx)`` instead of just the window.

        Returns
        -------
        numpy.ndarray, None or tuple
            The window copy, or None when no data has been buffered yet.
        """
        with self.lock:
            window = getattr(self, 'datawindow', None)
            out = None if window is None else window.copy()
            if returnIdx:
                out = (out, self.sampidx)
        return out

    def start(self):
        """Start the acquisition thread running `sampleLoop` on this object."""
        self.thread = threading.Thread(target=sampleLoop, args=(self,))
        self.thread.start()

    def stopit(self):
        """Ask the acquisition thread to stop, wait for it, close the socket."""
        self.stop = True
        thread = getattr(self, 'thread', None)
        # The thread only exists after start() has been called.
        if thread is not None:
            thread.join()
        try:
            self.sock.close()
        except OSError:
            # Best effort: the socket may already be closed or broken.
            pass