import sys
import re

# Matches GO accessions: the literal prefix "GO:" followed by one or more
# digits, e.g. GO:0005515.
pat_go = re.compile(r"GO:\d+")

# Two-column TSV (GO ID <TAB> GO name) opened by read_go(). The path is
# relative, so it is resolved against the current working directory.
go_def = "go_definition.tsv"

# The annotation TSV to convert is taken from the first command-line argument
# (presumably InterProScan TSV output, judging by the example paths below).
# Paths that were previously hard-coded here, kept as examples:
# input_file = "/home/ytanizaw/marchantia/func_annotation/interpro/interpro_r2_tsv/r2_interpro.tsv"
# input_file = "/home/ytanizaw/marchantia/assembly_hifi/func_annotation/MpTak1_v7.1/MpTak1_v7.1.interpro.tsv"
if len(sys.argv) < 2:
    # Exit with a usage message (non-zero status) instead of an IndexError
    # traceback when the argument is missing.
    sys.exit("usage: {} <interpro.tsv>".format(sys.argv[0]))
input_file = sys.argv[1]

def read_go(path=None):
    """Load the GO ID -> GO name table from a two-column, tab-separated file.

    Args:
        path: File to read. When omitted, the module-level ``go_def`` path is
            used, so the original zero-argument call keeps working.

    Returns:
        dict mapping the first column (GO ID) to the second column (GO name).
        If an ID occurs more than once, the last occurrence wins.

    Raises:
        ValueError: If a non-blank line does not consist of exactly two
            tab-separated fields.
    """
    if path is None:
        path = go_def
    D = {}
    # The context manager guarantees the handle is closed even on error.
    with open(path) as handle:
        for line_no, line in enumerate(handle, 1):
            line = line.rstrip("\n")
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            fields = line.split("\t")
            if len(fields) != 2:
                raise ValueError(
                    "{}:{}: expected 2 tab-separated fields, got {}".format(
                        path, line_no, len(fields)
                    )
                )
            go_id, go_name = fields
            D[go_id] = go_name
    return D

def extract_go_terms(input_string: str):
    """Return every GO accession found in *input_string*, in order of appearance.

    Matching is delegated to the module-level ``pat_go`` pattern. The result is
    a list of the matched substrings (duplicates kept) and is empty when
    nothing matches.
    """
    return [match.group(0) for match in pat_go.finditer(input_string)]

def read_interpro(file_name):
    """Collect the GO terms attached to each gene in a tab-separated file.

    Each line is split on tabs. The first column is used as the gene ID and
    the 14th column (index 13) is scanned for GO accessions. Lines with fewer
    than 14 columns, or whose 14th column does not start with "GO:", are
    skipped.

    Args:
        file_name: Path of the TSV file to read (presumably InterProScan
            output, where column 14 carries GO annotations).

    Returns:
        dict mapping gene ID -> set of GO accession strings. Genes without
        any GO term do not appear in the mapping.
    """
    D = {}
    # The context manager guarantees the handle is closed even on error.
    with open(file_name) as handle:
        for line in handle:
            cols = line.rstrip("\n").split("\t")
            if len(cols) < 14:
                # No GO column on this line.
                continue
            col_go = cols[13]
            if not col_go.startswith("GO:"):
                # GO column present but carrying no GO terms.
                continue
            go_terms = extract_go_terms(col_go)
            if go_terms:
                # Only create an entry when at least one term was found, so
                # genes never map to an empty set.
                D.setdefault(cols[0], set()).update(go_terms)
    return D

# Load the GO ID -> name table and the per-gene sets of GO IDs.
dict_go = read_go()
dict_ipr = read_interpro(input_file)

ref_db = "GO"
# Output columns (tab-separated): gene_id, ref_db, go_id, description
for gene_id, set_go in dict_ipr.items():
    # Sort the terms so the output is reproducible: iteration order of a set
    # of strings can differ from one interpreter run to the next.
    for go_id in sorted(set_go):
        go_name = dict_go.get(go_id)
        if go_name is None:
            # A GO ID absent from the definition table would otherwise abort
            # the run with a KeyError after part of the table was written.
            # Report it on stderr and keep stdout a clean table.
            print(
                "warning: no definition for {} (gene {}); skipped".format(
                    go_id, gene_id
                ),
                file=sys.stderr,
            )
            continue
        print("\t".join([gene_id, ref_db, go_id, go_name]))