#!/usr/bin/python3
#
# Copyright 2011 Jared Boone
#
# This file is part of Project Ubertooth.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; see the file COPYING.  If not, write to
# the Free Software Foundation, Inc., 51 Franklin Street,
# Boston, MA 02110-1301, USA.

import signal
import sys
import threading
import numpy

from argparse import ArgumentParser

from PySide import QtGui, QtCore
import PySide.QtGui as QtWidgets
from PySide.QtCore import Qt, QPointF, QLineF

from specan import Ubertooth

# Default scan range used when neither -l/-u nor --wifi is given, in MHz.
DEFAULT_LOWER_FREQ, DEFAULT_UPPER_FREQ = 2400, 2483

# Hard limits (MHz) enforced by check_freq(); going much further causes the
# Ubertooth to stop responding :(
MIN_FREQ, MAX_FREQ = 2300, 2600

class SpecanThread(threading.Thread):
    """Background thread that streams spectrum frames from a device.

    Each ``(frequency_axis, rssi_values)`` pair yielded by
    ``device.specan(...)`` is passed to ``new_frame_callback`` as numpy
    copies (presumably because the device may reuse its buffers -- the
    callback can keep what it receives).  The thread is a daemon, so it never
    keeps the interpreter alive on exit.
    """

    def __init__(self, device, low_frequency, high_frequency, new_frame_callback, ubertooth_device=-1):
        threading.Thread.__init__(self)
        self.daemon = True

        self._device = device
        self._ubertooth_device = ubertooth_device
        self._low_frequency = low_frequency
        self._high_frequency = high_frequency
        self._new_frame_callback = new_frame_callback
        self._stopping = False  # set by stop(); polled once per frame in run()
        self._stopped = False   # True once stop() has observed the thread exit

    def run(self):
        """Forward frames to the callback until stop() is requested or the source ends."""
        frame_source = self._device.specan(self._low_frequency, self._high_frequency,
                                           ubertooth_device=self._ubertooth_device)
        for frequency_axis, rssi_values in frame_source:
            # Check before delivering: a frame that arrives after stop() was
            # requested must not reach a consumer that may be tearing down.
            if self._stopping:
                break
            self._new_frame_callback(numpy.copy(frequency_axis), numpy.copy(rssi_values))

    def stop(self, timeout=3.0):
        """Request the thread to finish and wait up to ``timeout`` seconds.

        The stop flag is only polled between frames, so the wait can time out
        if the device stops producing frames; ``_stopped`` records whether the
        thread really exited.
        """
        self._stopping = True
        self.join(timeout)
        self._stopped = not self.is_alive()


class RenderArea(QtWidgets.QWidget):
    """Live spectrum plot: fading current trace, peak-hold trace, grid and markers.

    The constructor starts a SpecanThread whose frames arrive in ``_new_frame``
    on that background thread; all drawing happens in ``paintEvent``.  Cached
    drawing state:

    * ``_graph`` - pixmap holding the white current trace; each repaint washes
      it with translucent black so older traces fade out.
    * ``_reticle`` - transparent pixmap with the dBm / frequency grid, rebuilt
      only when missing or when the widget size changes.
    * ``_persisted_frames`` - ring buffer of the last
      ``_persisted_frames_depth`` RSSI frames; its per-bin maximum is the
      green peak-hold trace (``_path_max``).

    Marker state (``_mouse_x``/``_mouse_y`` and ``_mouse_x2``/``_mouse_y2``,
    as Hz and dBm, plus ``_hide_markers``) is written by the owning window.
    """

    # Fill value (-182) for peak-hold slots that have not received data yet.
    # It is below the -100 dBm plot floor, so a bin that has seen no data
    # stays off the plot and never triggers the peak marker.
    _EMPTY_SLOT_DBM = -128 + -54

    def __init__(self, device, parent=None, ubertooth_device=-1, lower_freq=DEFAULT_LOWER_FREQ, upper_freq=DEFAULT_UPPER_FREQ):
        QtWidgets.QWidget.__init__(self, parent)

        self._graph = None
        self._reticle = None

        self._device = device
        self._frame = None                    # latest (frequency_axis, rssi_values) pair
        self._persisted_frames = None         # created on the first frame (needs the bin count)
        self._persisted_frames_depth = 350
        self._persisted_frames_next_index = 0
        self._path_max = None                 # peak-hold QPainterPath built by _draw_graph

        # lower_freq/upper_freq are given in MHz; the widget works in Hz.
        self._low_frequency = lower_freq * 1e6
        self._high_frequency = upper_freq * 1e6
        self._frequency_step = 1e6
        self._high_dbm = 0.0
        self._low_dbm = -100.0

        self._hide_markers = False
        self._mouse_x = None
        self._mouse_y = None
        self._mouse_x2 = None
        self._mouse_y2 = None

        self._clear_scheduled = False

        self._thread = SpecanThread(self._device,
                                    self._low_frequency,
                                    self._high_frequency,
                                    self._new_frame,
                                    ubertooth_device=ubertooth_device)
        self._thread.start()

    def schedule_clear(self):
        """Request that the next repaint wipe the graph and reset the peak-hold buffer."""
        self._clear_scheduled = True

    def stop_thread(self):
        """Ask the background SpecanThread to stop and wait for it."""
        self._thread.stop()

    def _new_graph(self):
        """Replace the trace pixmap with a black one of the current widget size."""
        self._graph = QtGui.QPixmap(self.width(), self.height())
        self._graph.fill(Qt.black)

    def _new_reticle(self):
        """Replace the grid pixmap with a transparent one of the current widget size."""
        self._reticle = QtGui.QPixmap(self.width(), self.height())
        self._reticle.fill(Qt.transparent)

    def _new_persisted_frames(self, frequency_bins):
        """Reset the peak-hold ring buffer to empty rows of ``frequency_bins`` bins."""
        # Build the array completely before publishing it: _new_frame may
        # write into the buffer from the SpecanThread at any moment.
        frames = numpy.full((self._persisted_frames_depth, frequency_bins), float(self._EMPTY_SLOT_DBM))
        self._persisted_frames_next_index = 0
        self._persisted_frames = frames

    def minimumSizeHint(self):
        """Four pixels per frequency step horizontally, one pixel per dB vertically."""
        x_points = round((self._high_frequency - self._low_frequency) / self._frequency_step)
        y_points = round(self._high_dbm - self._low_dbm)
        return QtCore.QSize(x_points * 4, y_points * 1)

    def _new_frame(self, frequency_axis, rssi_values):
        """Store a frame in the peak-hold ring buffer and request a repaint.

        Called on the SpecanThread, not the GUI thread.
        """
        if self._persisted_frames is None:
            self._new_persisted_frames(len(frequency_axis))
        self._persisted_frames[self._persisted_frames_next_index] = rssi_values
        self._persisted_frames_next_index = (self._persisted_frames_next_index + 1) % self._persisted_frames.shape[0]
        # Publish the frame last so the GUI thread never sees a frame without
        # a peak-hold buffer to go with it.
        self._frame = (frequency_axis, rssi_values)
        # NOTE(review): update() is called from a non-GUI thread here; Qt
        # documents widgets as GUI-thread-only, so a queued signal would be safer.
        self.update()

    @staticmethod
    def _polyline(xs, ys):
        """Build a QPainterPath through the points (xs[i], ys[i]); xs must be non-empty."""
        path = QtGui.QPainterPath()
        path.moveTo(float(xs[0]), float(ys[0]))
        for x, y in zip(xs, ys):
            path.lineTo(float(x), float(y))
        return path

    def _draw_crosshair(self, painter, x, y):
        """Draw a full-height vertical and a full-width horizontal line through (x, y)."""
        painter.drawLine(QPointF(x, 0), QPointF(x, self.height()))
        painter.drawLine(QPointF(0, y), QPointF(self.width(), y))

    def _draw_markers(self, painter, peak_hz, peak_dbm):
        """Draw the red peak marker plus the left (yellow) and right (magenta) mouse markers.

        Bracketed labels are frequency differences in MHz between two markers.
        """
        mouse_x, mouse_y = self._mouse_x, self._mouse_y
        mouse_x2, mouse_y2 = self._mouse_x2, self._mouse_y2
        peak_mhz = peak_hz / 1e6
        peak_x = self._hz_to_x(peak_hz)
        peak_y = self._dbm_to_y(peak_dbm)

        pen = QtGui.QPen()
        pen.setBrush(Qt.red)
        pen.setStyle(Qt.DotLine)
        painter.setPen(pen)
        painter.drawText(QPointF(peak_x + 4, 30), '%.06f' % peak_mhz)
        painter.drawText(QPointF(30, peak_y - 4), '%d' % peak_dbm)
        self._draw_crosshair(painter, peak_x, peak_y)

        if mouse_x is not None:
            x1 = self._hz_to_x(mouse_x)
            y1 = self._dbm_to_y(mouse_y)
            painter.drawText(QPointF(x1 + 4, 58), '(%.06f)' % (peak_mhz - (mouse_x / 1e6)))
            pen.setBrush(Qt.yellow)
            painter.setPen(pen)
            painter.drawText(QPointF(x1 + 4, 44), '%.06f' % (mouse_x / 1e6))
            painter.drawText(QPointF(54, y1 - 4), '%d' % (mouse_y))
            self._draw_crosshair(painter, x1, y1)
            if mouse_x2 is not None:
                painter.drawText(QPointF(self._hz_to_x(mouse_x2) + 4, 118), '(%.06f)' % ((mouse_x / 1e6) - (mouse_x2 / 1e6)))

        if mouse_x2 is not None:
            x2 = self._hz_to_x(mouse_x2)
            y2 = self._dbm_to_y(mouse_y2)
            pen.setBrush(Qt.red)
            painter.setPen(pen)
            painter.drawText(QPointF(x2 + 4, 102), '(%.06f)' % (peak_mhz - (mouse_x2 / 1e6)))
            pen.setBrush(Qt.magenta)
            painter.setPen(pen)
            painter.drawText(QPointF(x2 + 4, 88), '%.06f' % (mouse_x2 / 1e6))
            painter.drawText(QPointF(78, y2 - 4), '%d' % (mouse_y2))
            self._draw_crosshair(painter, x2, y2)
            if mouse_x is not None:
                painter.drawText(QPointF(self._hz_to_x(mouse_x) + 4, 74), '(%.06f)' % ((mouse_x2 / 1e6) - (mouse_x / 1e6)))

    def _draw_graph(self):
        """Fade the trace pixmap, draw the newest trace on it and rebuild the peak-hold path."""
        frame = self._frame  # read once: the SpecanThread may replace it meanwhile

        if self._clear_scheduled:
            self._clear_scheduled = False
            self._new_graph()
            # Without a frame there is no bin count to size the buffer from,
            # and no buffer to reset yet (_new_frame creates it on first data).
            if frame is not None:
                self._new_persisted_frames(len(frame[0]))

        if self._graph is None or self._graph.size() != self.size():
            self._new_graph()

        painter = QtGui.QPainter(self._graph)
        try:
            painter.setRenderHint(QtGui.QPainter.Antialiasing)
            # Translucent black wash: previous traces fade a little every repaint.
            painter.fillRect(0, 0, self._graph.width(), self._graph.height(), QtGui.QColor(0, 0, 0, 10))

            if frame is None:
                return
            frequency_axis, rssi_values = frame
            if len(frequency_axis) == 0:
                return

            x_axis = self._hz_to_x(frequency_axis)
            peak_dbm = numpy.amax(self._persisted_frames, axis=0)

            pen = QtGui.QPen()
            pen.setBrush(Qt.white)
            painter.setPen(pen)
            painter.drawPath(self._polyline(x_axis, self._dbm_to_y(rssi_values)))

            # The peak-hold path is painted by paintEvent straight onto the
            # widget (so it does not fade); here it is only rebuilt.
            self._path_max = self._polyline(x_axis, self._dbm_to_y(peak_dbm))

            # First bin holding the highest peak; use the stored dBm/Hz values
            # directly rather than round-tripping through pixel coordinates.
            peak_index = int(numpy.argmax(peak_dbm))
            peak_value = float(peak_dbm[peak_index])
            if peak_value > self._low_dbm and not self._hide_markers:
                self._draw_markers(painter, float(frequency_axis[peak_index]), peak_value)
        finally:
            painter.end()

    def _draw_reticle(self):
        """(Re)build the grid pixmap when it is missing or the widget was resized."""
        if self._reticle is not None and self._reticle.size() == self.size():
            return
        self._new_reticle()

        left_x = self._hz_to_x(self._low_frequency)
        right_x = self._hz_to_x(self._high_frequency)
        top_y = self._dbm_to_y(self._high_dbm)
        bottom_y = self._dbm_to_y(self._low_dbm)
        dbm_ticks = numpy.arange(self._low_dbm, self._high_dbm, 20.0)
        frequency_ticks = numpy.arange(self._low_frequency, self._high_frequency, self._frequency_step * 10.0)

        painter = QtGui.QPainter(self._reticle)
        try:
            painter.setRenderHint(QtGui.QPainter.Antialiasing)

            painter.setPen(Qt.blue)
            # Lines are drawn one at a time instead of with drawLines() to
            # support the old (<1.0) PySide API shipped in Ubuntu 10.10.
            for dbm in dbm_ticks:
                y = self._dbm_to_y(dbm)
                painter.drawLine(QLineF(left_x, y, right_x, y))
            for frequency in frequency_ticks:
                x = self._hz_to_x(frequency)
                painter.drawLine(QLineF(x, top_y, x, bottom_y))

            painter.setPen(Qt.white)
            for dbm in dbm_ticks:
                painter.drawText(QPointF(left_x + 2, self._dbm_to_y(dbm) - 2), '%+.0f' % dbm)
            for frequency in frequency_ticks:
                painter.drawText(QPointF(self._hz_to_x(frequency) + 2, top_y + 10), '%.0f' % (frequency / 1e6))
        finally:
            painter.end()

    def paintEvent(self, event):
        """Compose the widget: trace pixmap, green peak-hold path, then the grid at 50% opacity."""
        self._draw_graph()
        self._draw_reticle()

        painter = QtGui.QPainter(self)
        try:
            painter.setRenderHint(QtGui.QPainter.Antialiasing)
            painter.setPen(QtGui.QPen())
            painter.setBrush(QtGui.QBrush())

            if self._graph:
                painter.drawPixmap(0, 0, self._graph)

            if self._path_max:
                painter.setPen(Qt.green)
                painter.drawPath(self._path_max)

            painter.setOpacity(0.5)
            if self._reticle:
                painter.drawPixmap(0, 0, self._reticle)
        finally:
            painter.end()

    def _hz_to_x(self, frequency_hz):
        """Map a frequency in Hz (scalar or numpy array) to a widget x coordinate."""
        span = self._high_frequency - self._low_frequency
        return (frequency_hz - self._low_frequency) / span * self.width()

    def _x_to_hz(self, x):
        """Inverse of _hz_to_x: widget x coordinate to frequency in Hz."""
        span = self._high_frequency - self._low_frequency
        return x / self.width() * span + self._low_frequency

    def _dbm_to_y(self, dbm):
        """Map a dBm value (scalar or numpy array) to a widget y coordinate (top = _high_dbm)."""
        span = self._high_dbm - self._low_dbm
        return (self._high_dbm - dbm) / span * self.height()

    def _y_to_dbm(self, y):
        """Inverse of _dbm_to_y: widget y coordinate to dBm."""
        span = self._high_dbm - self._low_dbm
        return self._high_dbm - y / self.height() * span


class Window(QtWidgets.QWidget):
    """Top-level spectrum-analyzer window.

    Opens the Ubertooth device and embeds a RenderArea that fills the whole
    window.  Mouse clicks and key presses are translated into the
    RenderArea's marker and clear controls (press H for the key list).
    """

    def __init__(self, parent=None, ubertooth_device=-1, lower_freq=DEFAULT_LOWER_FREQ, upper_freq=DEFAULT_UPPER_FREQ):
        QtWidgets.QWidget.__init__(self, parent)

        self._device = self._open_device()

        self.render_area = RenderArea(self._device, ubertooth_device=ubertooth_device, lower_freq=lower_freq, upper_freq=upper_freq)

        # Zero margins put the RenderArea at the window origin, so window
        # click coordinates can be fed to its converters unchanged.
        main_layout = QtWidgets.QGridLayout()
        main_layout.setContentsMargins(0, 0, 0, 0)
        main_layout.addWidget(self.render_area, 0, 0)
        self.setLayout(main_layout)

        self.setWindowTitle("Ubertooth Spectrum Analyzer")

    def sizeHint(self):
        return QtCore.QSize(480, 160)

    def _open_device(self):
        return Ubertooth.Ubertooth()

    def closeEvent(self, event):
        """Stop the capture thread before closing the device it reads from."""
        self.render_area.stop_thread()
        self._device.close()
        event.accept()

    def _toggle_markers(self):
        """Forget both mouse markers and toggle whether markers are drawn at all."""
        area = self.render_area
        area._mouse_x = None
        area._mouse_y = None
        area._mouse_x2 = None
        area._mouse_y2 = None
        area._hide_markers = not area._hide_markers

    def mousePressEvent(self, event):
        """Left/right click place the two markers; middle click clears and toggles them."""
        area = self.render_area
        button = event.button()
        if button == Qt.LeftButton:
            area._mouse_x = area._x_to_hz(float(event.x()))
            area._mouse_y = area._y_to_dbm(float(event.y()))
            area._hide_markers = False
        elif button == Qt.RightButton:
            area._mouse_x2 = area._x_to_hz(float(event.x()))
            area._mouse_y2 = area._y_to_dbm(float(event.y()))
            area._hide_markers = False
        elif button == Qt.MidButton:
            self._toggle_markers()
        event.accept()

    def keyPressEvent(self, event):
        """Handle the single-letter commands listed by the H key."""
        try:
            key = chr(event.key()).upper()
        except (ValueError, OverflowError):
            # Key codes outside chr()'s range (e.g. Qt's special keys such as
            # Shift or the arrows) are not commands.
            print('Unknown key pressed: 0x%x' % event.key())
            event.ignore()
            return
        event.accept()

        if key == 'H':
            print('Key                  Action\n')
            print(' <LEFT MOUSE>        Mark LEFT frequency / signal strength at pointer')
            print(' <RIGHT MOUSE>       Mark RIGHT frequency / signal strength at pointer')
            print(' <MIDDLE MOUSE>      Toggle visibility of frequency / signal strength markers')
            print(' C                   Clear graph')
            print(' H                   Print this HELP text')
            print(' M                   Simulate MIDDLE MOUSE click (for those with trackpads)')
            print(' Q                   Quit')
        elif key == 'M':
            self._toggle_markers()
        elif key == 'C':
            self.render_area.schedule_clear()
        elif key == 'Q':
            print('Quit!')
            self.close()
        else:
            print('Unsupported key pressed:', key)


def sigint_handler(*args):
    """Quit the Qt application when SIGINT (Ctrl-C) is received.

    Installed with signal.signal(); the (signum, frame) arguments it is
    called with are not needed and are ignored.
    """
    application = QtWidgets.QApplication
    application.quit()


def convert_wifi(channel):
    """Return the centre frequency in MHz of 2.4 GHz Wi-Fi ``channel`` (1-14).

    Channels 1-13 sit on a 5 MHz grid (2412, 2417, ... 2472 MHz); channel 14
    is off that grid at 2484 MHz.  For any other channel an error is printed
    and ValueError is raised.
    """
    if channel < 1 or channel > 14:
        print("ERROR: channel " + str(channel) + " is not a valid wifi channel")
        raise ValueError("invalid wifi channel: %r" % (channel,))

    if channel == 14:
        # 12 MHz above channel 13 rather than the usual 5 MHz step.
        return 2484
    return 2407 + 5 * channel

def check_freq(freq):
    """Validate that ``freq`` (MHz) lies within [MIN_FREQ, MAX_FREQ].

    Prints an error message and raises ValueError when it does not; returns
    None otherwise.
    """
    if freq < MIN_FREQ:
        problem = " MHz is below minimum frequency of " + str(MIN_FREQ)
    elif freq > MAX_FREQ:
        problem = " MHz is above maximum frequency of " + str(MAX_FREQ)
    else:
        return
    print("ERROR: frequency of " + str(freq) + problem)
    raise ValueError()


def check_freq_pair(freq1, freq2):
    """Validate a (lower, upper) scan range given in MHz.

    Each bound must pass check_freq(), and the lower bound must be strictly
    below the upper one: RenderArea divides by (upper - lower) when mapping
    frequencies to pixels, so an empty range would crash the display.
    Prints an error and raises ValueError when the range is invalid.
    """
    check_freq(freq1)
    check_freq(freq2)

    if freq1 > freq2:
        print("ERROR: lower frequency of " + str(freq1) + " MHz is above upper frequency of " + str(freq2) + " MHz")
        raise ValueError()
    if freq1 == freq2:
        print("ERROR: lower frequency of " + str(freq1) + " MHz is equal to upper frequency of " + str(freq2) + " MHz")
        raise ValueError()

def _parse_wifi_channels(spec):
    """Turn a --wifi argument ("6" or "1-11") into a (lower, upper) channel pair.

    Raises ValueError unless the text is one integer or two integers joined by
    a single "-".
    """
    parts = spec.split("-")
    if len(parts) == 1:
        channel = int(parts[0])
        return channel, channel
    if len(parts) == 2:
        return int(parts[0]), int(parts[1])
    raise ValueError("invalid channel range: %r" % (spec,))


def main():
    """Parse the command line, validate the scan range and run the Qt GUI until it exits."""
    signal.signal(signal.SIGINT, sigint_handler)

    parser = ArgumentParser()
    parser.add_argument("-U", type=int, dest="device", default=-1,
                        help="set ubertooth device to use")
    parser.add_argument("-l", type=int, dest="lower_freq", help="lower bound for scan, in MHz (no less than " + str(MIN_FREQ) + ")")
    parser.add_argument("-u", type=int, dest="upper_freq", help="upper bound for scan, in MHz (no more than " + str(MAX_FREQ) + ")")
    parser.add_argument("--wifi", type=str, nargs='?', dest="wifi", metavar="channel(s)",
                        help="display the spectrum for the wifi channels provided, either as a single number for one channel, or a range (e.g. 1-11) spanning every channel from the first to the last",
                        const="1", default=False)
    parser.add_argument("--padding", type=int, dest="padding", help="padding on both ends when using --wifi, measured in MHz (default 10)", default=10)

    # Unrecognised arguments are tolerated rather than rejected; the full
    # sys.argv is handed to QApplication below.
    options, _ = parser.parse_known_args()

    if options.wifi:
        try:
            lower_channel, upper_channel = _parse_wifi_channels(options.wifi)
        except ValueError:
            print("ERROR: invalid channel range: " + options.wifi)
            sys.exit(1)

        try:
            lower_freq = convert_wifi(lower_channel) - options.padding
            upper_freq = convert_wifi(upper_channel) + options.padding
        except ValueError:
            sys.exit(1)  # convert_wifi has already printed the reason
    else:
        # Compare with None so an explicit "-l 0" is reported as out of range
        # instead of being silently replaced by the default.
        lower_freq = DEFAULT_LOWER_FREQ if options.lower_freq is None else options.lower_freq
        upper_freq = DEFAULT_UPPER_FREQ if options.upper_freq is None else options.upper_freq

    try:
        check_freq_pair(lower_freq, upper_freq)
    except ValueError:
        sys.exit(1)  # check_freq_pair has already printed the reason

    app = QtWidgets.QApplication(sys.argv)
    window = Window(ubertooth_device=options.device, lower_freq=lower_freq, upper_freq=upper_freq)
    window.show()
    sys.exit(app.exec_())


if __name__ == '__main__':
    main()
