#! /usr/bin/python3

import os
import sys
import subprocess
import multiprocessing

def thread(task):
  """Run one mceliece-test invocation and classify its outcome.

  task is a tuple (valgrind,o,p,n,offset).  The command executed is
  `mceliece-test o p . n offset`; when valgrind is true the command is
  wrapped in `env valgrind_multiplier=1 valgrind ...` with
  --error-exitcode=99.

  Returns (task,(failure,stdout,stderr)).  failure is None if the run
  passed every check below, otherwise a short string naming the first
  check that failed.  Runs as a multiprocessing.Pool worker, which is why
  the task is echoed back alongside the result.
  """
  valgrind,o,p,n,offset = task

  argv = ['mceliece-test',o,p,'.',str(n),str(offset)]
  if valgrind:
    wrapper = ['env','valgrind_multiplier=1','valgrind','-q','--max-stackframe=16777216','--error-exitcode=99']
    argv = wrapper+argv

  proc = subprocess.run(argv,stdout=subprocess.PIPE,stderr=subprocess.PIPE,universal_newlines=True)
  out,err = proc.stdout,proc.stderr

  def outcome(failure):
    # every exit path reports the same captured output
    return task,(failure,out,err)

  if proc.returncode != 0:
    return outcome(f'nonzero return code {proc.returncode}')

  if err: # bug in test program
    return outcome('nonempty errors')

  # whitespace-split every output line once, then look for the marker lines
  words = [line.split() for line in out.splitlines()]

  if ['all','tests','succeeded'] not in words: # bug in test program
    return outcome('did not say all tests succeeded')

  if not any(len(w) >= 3 and w[2] == 'implementation' for w in words):
    return outcome('CPU does not support implementation')

  # the declassify marker is required only for valgrind runs
  if valgrind and ['valgrind','1','declassify','1'] not in words:
    return outcome('test does not support declassify')

  return outcome(None)

def checkresult(task,result):
  """Print the report for one finished task and say whether it passed.

  task is (valgrind,o,p,n,offset) and result is (failure,out,err), where
  failure is None on success or a string describing the failure.

  Every line of out and then of err is printed, each prefixed with a
  label identifying the task, followed by one final 'result:' line.
  stdout is flushed before returning.  Returns True exactly when failure
  is None.
  """
  valgrind,o,p,n,offset = task
  failure,out,err = result

  # valgrind runs are labelled 'dataflow', plain runs 'conventional'
  mode = 'dataflow' if valgrind else 'conventional'
  printtask = f'{o}/{p} impl {n} offset {offset} {mode}'

  for line in out.splitlines():
    print(f'{printtask} output: {line}')
  for line in err.splitlines():
    print(f'{printtask} error: {line}')

  succeeded = failure is None
  verdict = 'success' if succeeded else failure
  print(f'{printtask} result: {verdict}')

  sys.stdout.flush()
  return succeeded

def doit(todo):
  """Run every task in todo on a process pool; report results in todo order.

  todo is any iterable of hashable task tuples accepted by thread().
  Tasks finish in arbitrary order, but their reports are printed (via
  checkresult) strictly in the order they appear in todo: a finished
  task is held back until everything before it has been printed.

  The pool size defaults to the number of CPUs this process may run on,
  can be overridden with the THREADS environment variable, and is clamped
  to the range 1..len(todo).

  Returns True if every task succeeded (trivially True for an empty todo).
  Raises ValueError if THREADS is set to something int() cannot parse.
  """
  todo = list(todo)
  print(f'tests to run: {len(todo)}')

  if not todo:
    # nothing can fail, and multiprocessing.Pool rejects a size of 0
    return True

  try:
    threads = len(os.sched_getaffinity(0))
  except (AttributeError,OSError):
    # sched_getaffinity is not available on every platform
    threads = multiprocessing.cpu_count()
  threads = int(os.getenv('THREADS',default=threads))
  threads = min(max(threads,1),len(todo))
  print(f'maximum threads allowed: {threads}')

  results = {}
  printpos = 0 # index of the first task in todo not yet reported
  ok = True

  with multiprocessing.Pool(threads) as p:
    for task,result in p.imap_unordered(thread,todo,chunksize=1):
      results[task] = result
      # report the longest already-finished prefix of todo
      while printpos < len(todo) and todo[printpos] in results:
        ready = todo[printpos]
        if not checkresult(ready,results[ready]):
          ok = False
        printpos += 1

  assert printpos == len(todo)
  return ok

# Base table of tests: (o,p,n), i.e. the first, second and fourth
# arguments passed to mceliece-test; n is what reports label as 'impl'.
todo = [
  ('xof','shake256',-1),
  ('xof','shake256',0),
  ('xof','bitwrite16',-1),
  ('xof','bitwrite16',0),
  ('sort','int16',-1),
  ('sort','int16',0),
  ('sort','int32',-1),
  ('sort','int32',0),
  ('sort','int64',-1),
  ('sort','int64',0),
  ('kem','6960119',-1),
  ('kem','6960119',0),
  ('kem','6960119f',-1),
  ('kem','6960119f',0),
  ('kem','6960119pc',-1),
  ('kem','6960119pc',0),
  ('kem','6960119pcf',-1),
  ('kem','6960119pcf',0),
  ('kem','6688128',-1),
  ('kem','6688128',0),
  ('kem','6688128f',-1),
  ('kem','6688128f',0),
  ('kem','6688128pc',-1),
  ('kem','6688128pc',0),
  ('kem','6688128pcf',-1),
  ('kem','6688128pcf',0),
  ('kem','8192128',-1),
  ('kem','8192128',0),
  ('kem','8192128f',-1),
  ('kem','8192128f',0),
  ('kem','8192128pc',-1),
  ('kem','8192128pc',0),
  ('kem','8192128pcf',-1),
  ('kem','8192128pcf',0),
  ('kem','460896',-1),
  ('kem','460896',0),
  ('kem','460896f',-1),
  ('kem','460896f',0),
  ('kem','460896pc',-1),
  ('kem','460896pc',0),
  ('kem','460896pcf',-1),
  ('kem','460896pcf',0),
  ('kem','348864',-1),
  ('kem','348864',0),
  ('kem','348864f',-1),
  ('kem','348864f',0),
  ('kem','348864pc',-1),
  ('kem','348864pc',0),
  ('kem','348864pcf',-1),
  ('kem','348864pcf',0),
]

# Append the offset: 'sort' tests run with offset 0 only, all others with 0 and 1.
todo = [task+(offset,) for task in todo for offset in range(1 if task[0] == "sort" else 2)]
# Prepend the valgrind flag: every task runs without valgrind, and the
# offset-0 tasks additionally run under valgrind.
todo = [(False,)+task for task in todo]+[(True,)+task for task in todo if task[-1] == 0]

# Guard the run: under the spawn/forkserver start methods the pool workers
# re-import this module, and must not start the whole test run again.
if __name__ == '__main__':
  if not doit(todo):
    print('some tests failed')
    sys.exit(111)
  print('full tests succeeded')
