// ****************************************************************************
// Copyright(C) 2019 by Peter Birkholz, Dresden, Germany
// This file is part of the program MeasureTransferFunction.
// www.vocaltractlab.de
// ****************************************************************************

// ----------------------------------------------------------------------------
// Audio library for Windows systems.
// ----------------------------------------------------------------------------

#ifdef WIN32

#include "SoundLib.h"
#include "windows.h"
#include "mmsystem.h"
#include <iostream>

// For the volume control code:
#include <mmdeviceapi.h> 
#include <endpointvolume.h>
#include <audioclient.h>
#include <Functiondiscoverykeys_devpkey.h>

// Audio format shared by the input and output devices; filled in by initSound().
static WAVEFORMATEXTENSIBLE waveformatextensible;
static int samplingRate = 44100;        // Sampling rate in Hz; it will be overwritten in initSound().
static HWAVEOUT hWaveOut = NULL;        // Handle for output device.
static HWAVEIN  hWaveIn = NULL;         // Handle for input device.
static WAVEHDR waveOutHdr;              // Header describing the single playback buffer.
static WAVEHDR waveInHdr;               // Header describing the single recording buffer.
static bool isPlaying = false;          // Set by waveStartPlaying().
static bool isRecording = false;        // Set by waveStartRecording().
static bool playingInitialized = false;   // Set to true once hWaveOut was opened successfully.
static bool recordingInitialized = false; // Set to true once hWaveIn was opened successfully.

// Defined further below; declared here because initSound() calls them before
// their definitions.
void initWaveformInputDevice(WAVEFORMATEXTENSIBLE format, int deviceId);
void initWaveformOutputDevice(WAVEFORMATEXTENSIBLE format, int deviceId);


// ****************************************************************************
/// Enumerates the audio endpoints (either input or output) using the Windows 
/// Core API. For these endpoints, the volume can be set with 
/// winSetAudioEndpointVolume().
// ****************************************************************************

bool enumerateAudioEndpoints(vector<string> &endpointNames, vector<string> &endpointIds, bool outputDevices)
{
  HRESULT hr = S_OK;
  IMMDeviceEnumerator *pEnumerator = NULL;
  IMMDeviceCollection *pCollection = NULL;
  IMMDevice *pEndpoint = NULL;
  IPropertyStore *pProps = NULL;
  LPWSTR pwszID = NULL;

  // Clear the result lists.

  endpointNames.clear();
  endpointIds.clear();

  // ****************************************************************
  // Create a single uninitialized object of the class associated 
  // with MMDeviceEnumerator.
  // ****************************************************************

  hr = CoCreateInstance(__uuidof(MMDeviceEnumerator), NULL, CLSCTX_ALL,
    __uuidof(IMMDeviceEnumerator), (void**)&pEnumerator);

  if (FAILED(hr))
  {
    printf("CoCreateInstance() failed!\n");
    return false;
  }

  // ****************************************************************
  // Generate a collection of audio endpoint devices for both
  // recording and playback.
  // ****************************************************************

  EDataFlow dataFlow = eCapture;
  if (outputDevices)
  {
    dataFlow = eRender;
  }

  hr = pEnumerator->EnumAudioEndpoints(dataFlow, DEVICE_STATE_ACTIVE, &pCollection);

  if (FAILED(hr))
  {
    printf("IMMDeviceEnumerator::EnumAudioEndpoints() failed!\n");
    pEnumerator->Release();
    return false;
  }

  // ****************************************************************
  // Get the number of endpoint devices.
  // ****************************************************************

  UINT  count;
  hr = pCollection->GetCount(&count);

  if (FAILED(hr))
  {
    printf("IMMDeviceCollection::GetCount() failed!\n");
    pEnumerator->Release();
    pCollection->Release();
    return false;
  }

  // ****************************************************************
  // Each loop prints the name of an endpoint device.
  // ****************************************************************

  int i;
  wstring wst;
  string st;

  for (i = 0; i < (int)count; i++)
  {
    // Get pointer to endpoint number i.
    pCollection->Item(i, &pEndpoint);

    // Get the endpoint ID string.
    pEndpoint->GetId(&pwszID);

    pEndpoint->OpenPropertyStore(STGM_READ, &pProps);

    PROPVARIANT varName;
    // Initialize container for property value.
    PropVariantInit(&varName);

    // Get the endpoint's friendly-name property.
    pProps->GetValue(PKEY_Device_FriendlyName, &varName);

    // Add the endpoint name and ID to the lists.

    wst = wstring(varName.pwszVal);
    st.assign(wst.begin(), wst.end());  // Convert wstring to string
    endpointNames.push_back(st);

    wst = wstring(pwszID);
    st.assign(wst.begin(), wst.end());  // Convert wstring to string
    endpointIds.push_back(st);

    // Free memory.

    CoTaskMemFree(pwszID);
    pwszID = NULL;
    PropVariantClear(&varName);

    if (pProps != NULL)
    {
      pProps->Release();
      pProps = NULL;
    }

    if (pEndpoint != NULL)
    {
      pEndpoint->Release();
      pEndpoint = NULL;
    }
  }

  // ****************************************************************
  // Free memory for enumerator and collection.
  // ****************************************************************

  if (pEnumerator == NULL)
  {
    pEnumerator->Release();
  }

  if (pCollection == NULL)
  {
    pCollection->Release();
  }
  
  return true;
}


// ****************************************************************************
/// Sets the volume (between 0.0 and 1.0) for the audio endpoint with the 
/// given ID. The enpoint names associated with the indices can be obtained
/// with enumerateAudioEnpoints().
// ****************************************************************************

bool setAudioEndpointVolume(string endpointId, float volume)
{
  IMMDeviceEnumerator* m_pEnumerator = NULL;
  IMMDevice* m_pDevice = NULL;
  IAudioEndpointVolume* m_AudioEndpointVolume = NULL;

  // ****************************************************************
  // Create a single uninitialized object of the class associated 
  // with MMDeviceEnumerator.
  // ****************************************************************

  HRESULT hr = CoCreateInstance(__uuidof(MMDeviceEnumerator), NULL, CLSCTX_ALL,
    __uuidof(IMMDeviceEnumerator), (void**)&m_pEnumerator);

  if (FAILED(hr))
  {
    printf("CoCreateInstance() failed!\n");
    return false;
  }

  // ****************************************************************
  // Retrieves the audio endpoint for the specified ID string.
  // ****************************************************************

  // Convert entpointId into a wide string.
  wstring endpointIdWide(endpointId.begin(), endpointId.end());

  hr = m_pEnumerator->GetDevice(endpointIdWide.c_str(), &m_pDevice);

  if (FAILED(hr))
  {
    printf("IMMDeviceEnumerator::GetDevice() failed!\n");
    m_pEnumerator->Release();
    return false;
  }  

  // ****************************************************************
  // Creates a COM object for the interface IAudioEndpointVolume
  // by calling Activate() on the device object (IMMDevice).
  // ****************************************************************

  hr = m_pDevice->Activate(__uuidof(IAudioEndpointVolume), CLSCTX_ALL,
    NULL, (void**)&m_AudioEndpointVolume);

  if (FAILED(hr))
  {
    printf("IMMDevice::Activate() failed!\n");
    m_pEnumerator->Release();
    m_pDevice->Release();
    return false;
  }
  
  // ****************************************************************
  // Get information about the current volume level.
  // ****************************************************************

  float prevLevel;
  m_AudioEndpointVolume->GetMasterVolumeLevelScalar(&prevLevel);

//  printf("Current volume level of audio endpoint %S is %2.2f.\n", endpointId, prevLevel);

  // ****************************************************************
  // Set the master volume level of the endpoint device.
  // 0.0 <= volume <= 1.0
  // ****************************************************************

  m_AudioEndpointVolume->SetMasterVolumeLevelScalar(volume, &GUID_NULL);
  printf("New volume level set to %2.2f.\n", volume);

  // ****************************************************************

  m_pEnumerator->Release();
  m_pDevice->Release();
  m_AudioEndpointVolume->Release();

  return true;
}


// ****************************************************************************
/// Finds the index of the endpoint in the audio endpoint list that corresponds
/// to the given device name. An endpoint matches if its name is contained
/// (as a substring) in the given device name. If several endpoints match,
/// the one with the highest index is returned.
/// If no endpoint matches, -1 is returned.
// ****************************************************************************

int getEndpointIndexFor(string deviceName, vector<string> &endpointNames)
{
  // Search from the back, so the first hit is the highest matching index.
  for (int i = (int)endpointNames.size() - 1; i >= 0; i--)
  {
    if (deviceName.find(endpointNames[i]) != string::npos)
    {
      return i;
    }
  }

  return -1;
}


// ****************************************************************************
// Initialization.
// Set inputDeviceId and outputDeviceId to WAVE_MAPPER to use the default
// devices.
// ****************************************************************************

void initSound(int SR, int inputDeviceId, int outputDeviceId)
{
  samplingRate = SR;
  
  // Prepare audio device for output (24 Bit, mono, samplingRate).
  waveformatextensible.Format.wFormatTag = WAVE_FORMAT_EXTENSIBLE;
  waveformatextensible.Format.nChannels = 1;
  waveformatextensible.Format.nSamplesPerSec = samplingRate;
  waveformatextensible.Format.nAvgBytesPerSec = samplingRate * 4;
  waveformatextensible.Format.nBlockAlign = 4;
  waveformatextensible.Format.wBitsPerSample = 32;
  // cbSize specifies how many bytes of additional format data are appended to the WAVEFORMATEX structure:
  waveformatextensible.Format.cbSize = sizeof(WAVEFORMATEXTENSIBLE) - sizeof(WAVEFORMATEX);
  waveformatextensible.Samples.wValidBitsPerSample = 32;

  waveformatextensible.dwChannelMask = SPEAKER_FRONT_LEFT;
  waveformatextensible.SubFormat = KSDATAFORMAT_SUBTYPE_IEEE_FLOAT;

  // ****************************************************************

  initWaveformInputDevice(waveformatextensible, inputDeviceId);
  initWaveformOutputDevice(waveformatextensible, outputDeviceId);
}


// ****************************************************************************
// Opens (or re-opens) the waveform input device with the given format.
// A previously opened input device is stopped and closed first.
// deviceId is a waveIn device index, or WAVE_MAPPER for the default device.
// On success recordingInitialized is set to true; on failure it is set to
// false and an error message is printed.
// ****************************************************************************

void initWaveformInputDevice(WAVEFORMATEXTENSIBLE format, int deviceId)
{
  if (recordingInitialized)
  {
    // waveInClose() fails with WAVERR_STILLPLAYING while a buffer is still
    // queued, so stop the device and release the header first.
    waveInReset(hWaveIn);
    if (waveInHdr.dwFlags & WHDR_PREPARED)
    {
      waveInUnprepareHeader(hWaveIn, &waveInHdr, sizeof(WAVEHDR));
    }
    waveInClose(hWaveIn);
    hWaveIn = NULL;
    recordingInitialized = false;
    isRecording = false;
  }

  printf("Trying to initialize the audio input device with %lu Hz and %u bit...\n",
    (unsigned long)format.Format.nSamplesPerSec, (unsigned)format.Format.wBitsPerSample);

  MMRESULT mmresult = waveInOpen(&hWaveIn, (UINT)deviceId, &format.Format, 0, 0, CALLBACK_NULL);

  if (mmresult != MMSYSERR_NOERROR)
  {
    // Opening the waveform input device failed; do not keep a stale handle.
    hWaveIn = NULL;
    recordingInitialized = false;
    printf("Error: Opening the audio input device with ID=%d for recording failed with error %u !\n",
      deviceId, (unsigned)mmresult);
  }
  else
  {
    printf("Initialization succeeded.\n");
    recordingInitialized = true;
  }
}

// ****************************************************************************
// Opens (or re-opens) the waveform output device with the given format.
// A previously opened output device is stopped and closed first.
// deviceId is a waveOut device index, or WAVE_MAPPER for the default device.
// On success playingInitialized is set to true; on failure it is set to
// false and an error message is printed.
// ****************************************************************************

void initWaveformOutputDevice(WAVEFORMATEXTENSIBLE format, int deviceId)
{
  if (playingInitialized)
  {
    // waveOutClose() fails with WAVERR_STILLPLAYING while a buffer is still
    // queued, so stop the device and release the header first.
    waveOutReset(hWaveOut);
    if (waveOutHdr.dwFlags & WHDR_PREPARED)
    {
      waveOutUnprepareHeader(hWaveOut, &waveOutHdr, sizeof(WAVEHDR));
    }
    waveOutClose(hWaveOut);
    hWaveOut = NULL;
    playingInitialized = false;
    isPlaying = false;
  }

  printf("Trying to initialize the audio output device with %lu Hz and %u bit...\n",
    (unsigned long)format.Format.nSamplesPerSec, (unsigned)format.Format.wBitsPerSample);

  MMRESULT mmresult = waveOutOpen(&hWaveOut, (UINT)deviceId, &format.Format, 0, 0, CALLBACK_NULL);

  if (mmresult != MMSYSERR_NOERROR)
  {
    // Opening the waveform output device failed; do not keep a stale handle.
    hWaveOut = NULL;
    playingInitialized = false;
    printf("Error: Opening the audio output device with ID=%d for playback failed with error %u !\n",
      deviceId, (unsigned)mmresult);
  }
  else
  {
    printf("Initialization succeeded.\n");
    playingInitialized = true;
  }
}

// ****************************************************************************
// Close the input and output devices.
// ****************************************************************************

void exitSound()
{
  if (playingInitialized) { waveOutClose(hWaveOut); }
  if (recordingInitialized) { waveInClose(hWaveIn); }
}

// ****************************************************************************
// Startet das Abspielen in einer Endlosschleife.
// ****************************************************************************

bool waveStartPlaying(float *data, int numSamples)
{
  if (playingInitialized == false) { return false; }

  waveOutHdr.lpData = (LPSTR)data;
  waveOutHdr.dwBufferLength = numSamples*4;   // Angabe in Bytes
  waveOutHdr.dwBytesRecorded = 0;
  waveOutHdr.dwUser = 0;
  waveOutHdr.dwFlags = 0;
  waveOutHdr.dwLoops = 1;
  waveOutHdr.lpNext = NULL;
  waveOutHdr.reserved = 0;

  waveOutPrepareHeader(hWaveOut, &waveOutHdr, sizeof(WAVEHDR));

//  waveOutHdr.dwFlags|= WHDR_BEGINLOOP | WHDR_ENDLOOP;
  waveOutWrite(hWaveOut, &waveOutHdr, sizeof(WAVEHDR));
  isPlaying = true;
  return true;
}

// ****************************************************************************
// Beendet das Abspielen.
// ****************************************************************************

bool waveStopPlaying()
{
  if (playingInitialized == false) { return false; }

  waveOutReset(hWaveOut);     // Stoppt die Wiedergabe
  waveOutUnprepareHeader(hWaveOut, &waveOutHdr, sizeof(WAVEHDR));
  isPlaying = false;
  return true;
}

// ****************************************************************************
// Startet die Aufnahme in den Ringpuffer.
// ****************************************************************************

bool waveStartRecording(float *data, int numSamples)
{
  if (recordingInitialized == false) 
  { 
    return false; 
  }

  waveInHdr.lpData = (LPSTR)data;
  waveInHdr.dwBufferLength = numSamples*4;   // Angabe in Bytes
  waveInHdr.dwBytesRecorded = 0;
  waveInHdr.dwUser = 0;
  waveInHdr.dwFlags = 0;
  waveInHdr.dwLoops = 1;          // wird bei der Aufnahme ignoriert
  waveInHdr.lpNext = NULL;
  waveInHdr.reserved = 0;

  waveInPrepareHeader(hWaveIn, &waveInHdr, sizeof(WAVEHDR));
  waveInAddBuffer(hWaveIn, &waveInHdr, sizeof(WAVEHDR));
  waveInStart(hWaveIn);
  
  isRecording = true;

  return true;
}

// ****************************************************************************
// Stoppt die Aufnahme.
// ****************************************************************************

bool waveStopRecording()
{
  if (recordingInitialized == false) { return false; }
  waveInUnprepareHeader(hWaveIn, &waveInHdr, sizeof(WAVEHDR));
  waveInReset(hWaveIn);     // Stoppt die Aufnahme

  return true;
}

// ****************************************************************************

#endif

